@@ -346,6 +346,9 @@ def __call__(
         generator: Optional[torch.Generator] = None,
         use_fp16: bool = False,
         image_size: int = 256,
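+        # New micro-conditioning inputs (original image size, top-left crop
+        # coordinates, aesthetic score), consumed below when the transformer
+        # config sets add_micro_cond_embeds.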
+        orig_size=(256, 256),
+        crop_coords=(0, 0),
+        aesthetic_score=6.0,
     ):
         from torchvision import transforms
@@ -366,7 +369,9 @@ def __call__(
             pixel_values = encode_transform(image).unsqueeze(0).to(self.device)
             _, image_tokens = self.vae.encode(pixel_values)
             mask_token_id = self.transformer.config.mask_token_id
+
             image_tokens[mask[None]] = mask_token_id
+
             image_tokens = image_tokens.repeat(num_images_per_prompt, 1)
         if class_ids is not None:
             if isinstance(class_ids, int):
@@ -388,7 +393,13 @@ def __call__(
                 max_length=self.tokenizer.model_max_length,
             ).input_ids  # TODO: remove hardcode
             input_ids = input_ids.to(self.device)
-            encoder_hidden_states = self.text_encoder(input_ids).last_hidden_state
+
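+            # When the transformer expects extra conditioning, take the pooled
+            # text embedding plus the penultimate hidden state
+            # (hidden_states[-2]) instead of the final-layer output.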
+            if self.transformer.config.add_cond_embeds:
+                outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
+                pooled_embeds, encoder_hidden_states = outputs.text_embeds, outputs.hidden_states[-2]
+            else:
+                encoder_hidden_states = self.text_encoder(input_ids).last_hidden_state
+                pooled_embeds = None

             if negative_text is not None:
                 if isinstance(negative_text, str):
@@ -417,10 +428,27 @@ def __call__(
                 bs_embed * num_images_per_prompt, seq_len, -1
             )

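+        # Embeddings of the empty prompt "", passed to the generator as the
+        # unconditional inputs (e.g. for classifier-free guidance).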
+        empty_input = self.tokenizer("", padding="max_length", return_tensors="pt").input_ids.to(
+            self.text_encoder.device
+        )
+        outputs = self.text_encoder(empty_input, output_hidden_states=True)
+        empty_embeds = outputs.hidden_states[-2]
+        empty_cond_embeds = outputs[0]
+
         model_inputs = {
             "encoder_hidden_states": encoder_hidden_states,
             "negative_embeds": negative_encoder_hidden_states,
+            "empty_embeds": empty_embeds,
+            "empty_cond_embeds": empty_cond_embeds,
+            "cond_embeds": pooled_embeds,
         }
+
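+        # Pack the micro-conditioning values (original size, crop coordinates,
+        # aesthetic score) into a single tensor of shape (1, 5).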
+        if self.transformer.config.add_micro_cond_embeds:
+            micro_conds = list(orig_size) + list(crop_coords) + [aesthetic_score]
+            micro_conds = torch.tensor(micro_conds, device=self.device, dtype=encoder_hidden_states.dtype)
+            micro_conds = micro_conds.unsqueeze(0)
+            model_inputs["micro_conds"] = micro_conds
+
         generate = self.transformer.generate2
         with torch.autocast("cuda", enabled=use_fp16):
             generated_tokens = generate(