diff --git a/llava/mm_utils.py b/llava/mm_utils.py index de97345cf..c8dd959f6 100644 --- a/llava/mm_utils.py +++ b/llava/mm_utils.py @@ -6,7 +6,7 @@ import ast from transformers import StoppingCriteria -from llava.constants import IMAGE_TOKEN_INDEX +from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN def select_best_resolution(original_size, possible_resolutions): @@ -183,7 +183,7 @@ def process_images(images, image_processor, model_cfg): def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): - prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] + prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)] def insert_separator(X, sep): return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]