
How is the platform field determined in the ScreenSpot evaluation, and is history information used during inference? #18

Open
lijianhong-code opened this issue Jan 5, 2025 · 4 comments

lijianhong-code commented Jan 5, 2025

In the ScreenSpot evaluation, how is the platform field of each sample determined? For example, how are samples whose original data_source is web or tool assigned to Mac, Mobile, or WIN? Also, is history information used during inference? When I reproduce the model's evaluation on ScreenSpot, my score (about 70%) falls well short of the reported 85.4%.


jasonnoy commented Jan 6, 2025

During model training, the platform field is omitted with some probability, so if you are unsure which platform a sample was collected on, you can either leave the platform field out or try the default platform, WIN. For history information, please refer to the "History field" section of the prompt-concatenation document: https://zhipu-ai.feishu.cn/wiki/D9FTwQ78fitS3CkZHUjcKEWTned. You can provide more details of your evaluation setup, such as whether you used model quantization and the exact prompt-concatenation code, so that we can help you find potential problems.
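
For illustration, a minimal sketch of the query assembly this implies, mirroring the repo's demo script; the task text below is hypothetical, and the exact History wording should be checked against the linked document:

```python
# Sketch of prompt assembly (assumption: mirrors the cli_demo query format).
task = "Open the settings menu"      # hypothetical task text
platform_str = "(Platform: WIN)\n"   # drop this piece entirely if the platform is unknown
format_str = "(Answer in Status-Action-Operation-Sensitive format.)"
history_str = ""                     # single-step grounding (e.g. ScreenSpot) needs no history
# For multi-step tasks, history_str would hold a "\nHistory steps:" block per the linked doc.
query = f"Task: {task}{history_str}\n{platform_str}{format_str}"
```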


lijianhong-code commented Jan 6, 2025

> During model training, the platform field is omitted with some probability, so if you are unsure which platform a sample was collected on, you can either leave the platform field out or try the default platform, WIN. For history information, please refer to the "History field" section of the prompt-concatenation document: https://zhipu-ai.feishu.cn/wiki/D9FTwQ78fitS3CkZHUjcKEWTned. You can provide more details of your evaluation setup, such as whether you used model quantization and the exact prompt-concatenation code, so that we can help you find potential problems.

I ran inference on the ScreenSpot benchmark with the cogagent-9b-20241220 weights, without adding a history field and without quantization. The correctness criterion: a prediction counts as correct if the center of the predicted box falls inside the ground-truth annotation box.
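
Stated as code, the criterion is the following (a minimal sketch; the helper name is mine, and it assumes ScreenSpot ground truth is [x, y, w, h] in pixels while the model predicts 0-1000 normalized coordinates):

```python
def center_hit(pred_box, gt_xywh, img_w, img_h):
    """True if the predicted box center (0-1000 normalized) lies inside the ground-truth box (pixels)."""
    cx = (pred_box[0] + pred_box[2]) / 2 / 1000 * img_w
    cy = (pred_box[1] + pred_box[3]) / 2 / 1000 * img_h
    x, y, w, h = gt_xywh
    return x <= cx <= x + w and y <= cy <= y + h
```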

lijianhong-code (Author)

The inference code is as follows:

```python
import argparse
import csv
import json
import logging
import os
import re

import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


def main_ScreenSpot():
    """
    Batch inference over ScreenSpot using the CogAgent model with selectable format prompts.
    The output_image_path is interpreted as a directory. For each sample,
    the annotated image will be saved in the directory with the filename:
    {original_image_name_without_extension}_{round_number}.png

    Example:
    python cli_demo_my.py --model_dir ../cogagent-9b-20241220 --platform "Mac" --max_length 4096 --top_k 1 \
        --output_image_path ./results --format_key status_action_op_sensitive
    """
    parser = argparse.ArgumentParser(
        description="Batch inference on ScreenSpot with the CogAgent model and a selectable format."
    )
    parser.add_argument(
        "--model_dir",
        default="../cogagent-9b-20241220",
        help="Path or identifier of the model.",
    )
    parser.add_argument(
        "--platform",
        default="Mac",
        help="Platform information string (e.g., 'Mac', 'WIN'); overridden per sample below.",
    )
    parser.add_argument(
        "--max_length", type=int, default=4096, help="Maximum generation length."
    )
    parser.add_argument(
        "--top_k", type=int, default=1, help="Top-k sampling parameter."
    )
    parser.add_argument(
        "--output_image_path",
        default="image_results",
        help="Directory to save the annotated images.",
    )
    parser.add_argument(
        "--output_pred_path",
        default="./ScreenSpot.csv",
        help="File path for the prediction CSV.",
    )
    parser.add_argument(
        "--format_key",
        default="status_action_op_sensitive",
        help="Key to select the prompt format.",
    )
    parser.add_argument(
        "--task_path",
        default="../ScreenSpot/ScreenSpot_combined.json",
        help="Path to the ScreenSpot annotation JSON.",
    )
    parser.add_argument(
        "--image_path",
        default="../ScreenSpot/images",
        help="Directory containing the ScreenSpot images.",
    )
    args = parser.parse_args()

    # Dictionary mapping format keys to format strings
    format_dict = {
        "action_op_sensitive": "(Answer in Action-Operation-Sensitive format.)",
        "status_plan_action_op": "(Answer in Status-Plan-Action-Operation format.)",
        "status_action_op_sensitive": "(Answer in Status-Action-Operation-Sensitive format.)",
        "status_action_op": "(Answer in Status-Action-Operation format.)",
        "action_op": "(Answer in Action-Operation format.)",
    }

    # Ensure the provided format_key is valid
    if args.format_key not in format_dict:
        raise ValueError(
            f"Invalid format_key. Available keys are: {list(format_dict.keys())}"
        )

    # Ensure the output directory exists
    os.makedirs(args.output_image_path, exist_ok=True)

    # Load the tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_dir,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
        # quantization_config=BitsAndBytesConfig(load_in_8bit=True),  # For INT8 quantization
        # quantization_config=BitsAndBytesConfig(load_in_4bit=True),  # For INT4 quantization
    ).eval()

    # Initialize platform and selected format strings
    platform_str = f"(Platform: {args.platform})\n"
    format_str = format_dict[args.format_key]

    # Initialize history lists (kept for parity with the demo; history is NOT
    # concatenated into the query below)
    history_step = []
    history_action = []
    pre_result_list = []

    round_num = 1

    with open(args.task_path, "r") as file:
        data = json.load(file)

    task_id = 0
    for item in tqdm(data):
        for one in item["annotations"]:
            logging.info(f"Running inference on sample {task_id}...")
            pred_result = one.copy()
            task = one["objective_reference"]
            img_path = args.image_path + "/" + one["image_id"]
            # Map the original data_source to a platform string
            if one["data_source"] in ["ios", "android"]:
                platform = "Mobile"
            elif one["data_source"] in ["macos"]:
                platform = "Mac"
            else:
                platform = "WIN"
            platform_str = f"(Platform: {platform})\n"

            try:
                image = Image.open(img_path).convert("RGB")
            except Exception:
                logging.info("Invalid image path. Please try again.")
                continue

            # Verify history lengths match
            if len(history_step) != len(history_action):
                raise ValueError("Mismatch in lengths of history_step and history_action.")

            # Format history steps for output
            history_str = "\nHistory steps: "
            for index, (step, action) in enumerate(zip(history_step, history_action)):
                history_str += f"\n{index}. {step}\t{action}"

            # Compose the query with task, platform, and selected format instructions
            # (history_str is intentionally left out of the query)
            # query = f"Task: {task}{history_str}\n{platform_str}{format_str}"
            query = f"Task: {task}\n{platform_str}{format_str}"

            logging.info(f"Round {round_num} query:\n{query}")

            inputs = tokenizer.apply_chat_template(
                [{"role": "user", "image": image, "content": query}],
                add_generation_prompt=True,
                tokenize=True,
                return_tensors="pt",
                return_dict=True,
            ).to(model.device)

            # Generation parameters (do_sample=True with top_k=1 is effectively greedy)
            gen_kwargs = {
                "max_length": args.max_length,
                "do_sample": True,
                "top_k": args.top_k,
            }

            # Generate response
            with torch.no_grad():
                outputs = model.generate(**inputs, **gen_kwargs)
                outputs = outputs[:, inputs["input_ids"].shape[1]:]
                response = tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract grounded operation and action
            grounded_pattern = r"Grounded Operation:\s*(.*)"
            action_pattern = r"Action:\s*(.*)"
            matches_history = re.search(grounded_pattern, response)
            matches_actions = re.search(action_pattern, response)

            if matches_history:
                grounded_operation = matches_history.group(1)
                history_step.append(grounded_operation)
            if matches_actions:
                action_operation = matches_actions.group(1)
                history_action.append(action_operation)

            # Extract bounding boxes from the response
            box_pattern = r"box=\[\[?(\d+),(\d+),(\d+),(\d+)\]?\]"
            matches = re.findall(box_pattern, response)

            if matches:
                boxes = [[int(x) / 1000 for x in match] for match in matches]

                # Extract base name of the input image (without extension)
                base_name = os.path.splitext(os.path.basename(img_path))[0]
                # Construct the output file name with round number
                output_file_name = f"{base_name}_{round_num}.png"
                output_path = os.path.join(args.output_image_path, output_file_name)

                # draw_boxes_on_image is the annotation helper from the repo's demo utilities
                draw_boxes_on_image(image, boxes, output_path)

                pred_boxes = [[int(x) for x in match] for match in matches]
                logging.info(f"Annotated image saved at: {output_path}")
                pred_result.update({"pred_box": pred_boxes, "platform": platform, "Model_response": response})
            else:
                logging.info("No bounding boxes found in the response.")
                pred_result.update({"pred_box": "", "Model_response": response})

            task_id += 1
            round_num += 1
            logging.info(pred_result)
            pre_result_list.append(pred_result)

    # Remove any stale prediction file before writing a fresh one
    if os.path.exists(args.output_pred_path):
        os.remove(args.output_pred_path)

    keys = pre_result_list[0].keys()
    with open(args.output_pred_path, "w", encoding="utf-8", newline="") as output_file:
        dict_writer = csv.DictWriter(output_file, keys)
        dict_writer.writeheader()
        dict_writer.writerows(pre_result_list)


if __name__ == "__main__":
    main_ScreenSpot()
```
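
One detail worth flagging in the script above: `do_sample=True` with `top_k=1` restricts sampling to the single most probable token, so decoding is effectively greedy but still routed through the sampler. For a reproducible benchmark run, plain greedy decoding is an alternative (a sketch; whether it moves the score is not established here):

```python
# Deterministic alternative: greedy decoding, no top-k sampling.
gen_kwargs = {
    "max_length": args.max_length,
    "do_sample": False,
}
```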


lijianhong-code commented Jan 6, 2025

The evaluation code is as follows:

```python
import csv

import numpy as np
from PIL import Image


def scale_box_to_pixels(img_path, box):
    """Convert a box in 0-1000 normalized coordinates to pixel coordinates.

    (Renamed from draw_boxes_on_image: this helper only rescales, it draws nothing.)
    """
    image = Image.open(img_path).convert("RGB")
    x_min = int(box[0] / 1000 * image.width)
    y_min = int(box[1] / 1000 * image.height)
    x_max = int(box[2] / 1000 * image.width)
    y_max = int(box[3] / 1000 * image.height)
    return [x_min, y_min, x_max, y_max]


def evaluate(df):
    data_dict_list = []
    # The prediction CSV was written with encoding="utf-8"; read it back the same way
    # (the original used encoding="GBK", which can garble or reject rows).
    with open(df, "r", encoding="utf-8") as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            data_dict_list.append(row)

    true_num = 0   # center of the predicted box falls inside the ground-truth box
    wrong_num = 0  # center misses the ground-truth box
    for id, item in enumerate(data_dict_list):
        try:
            pred_box = np.array(eval(item["pred_box"])[0])
            true_box = np.array(eval(item["bounding_box"]))
            # Convert the predicted box (0-1000 normalized) to pixel coordinates
            img_path = f"./ScreenSpot/images/{item['image_id']}"
            pred_box = scale_box_to_pixels(img_path, pred_box)
            x_min, y_min, x_max, y_max = pred_box
            # Ground truth is [x, y, w, h]; convert to [x_min, y_min, x_max, y_max]
            true_box = [true_box[0], true_box[1], true_box[0] + true_box[2], true_box[1] + true_box[3]]
            pred_center_x = (x_min + x_max) / 2
            pred_center_y = (y_min + y_max) / 2

            # Correct if the predicted center lies inside the ground-truth box
            if (true_box[0] <= pred_center_x <= true_box[2]) and (true_box[1] <= pred_center_y <= true_box[3]):
                true_num += 1
            else:
                print(true_box, pred_center_x, pred_center_y)
                print("wrong", item["platform"], item["data_source"])
                wrong_num += 1
        except Exception:
            # Samples with no parsable pred_box are skipped here, but they still
            # count in the denominator below, i.e. they are scored as wrong.
            continue
    print(true_num / len(data_dict_list), true_num, wrong_num, len(data_dict_list))
```
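
Since the gap (~70% vs. the reported 85.4%) may be concentrated in particular splits, a per-platform breakdown can help localize it. A minimal sketch, assuming the same CSV fields as above (platform, pred_box, bounding_box, image_id) and reusing scale_box_to_pixels:

```python
import csv
from collections import defaultdict

import numpy as np


def evaluate_by_platform(csv_path):
    """Center-hit accuracy per platform; unparsable predictions count as misses."""
    hits, totals = defaultdict(int), defaultdict(int)
    with open(csv_path, "r", encoding="utf-8") as f:
        for item in csv.DictReader(f):
            plat = item.get("platform") or "unknown"
            totals[plat] += 1
            try:
                pred = np.array(eval(item["pred_box"])[0])
                gt = np.array(eval(item["bounding_box"]))  # [x, y, w, h] in pixels
                img_path = f"./ScreenSpot/images/{item['image_id']}"
                x_min, y_min, x_max, y_max = scale_box_to_pixels(img_path, pred)
                cx, cy = (x_min + x_max) / 2, (y_min + y_max) / 2
                if gt[0] <= cx <= gt[0] + gt[2] and gt[1] <= cy <= gt[1] + gt[3]:
                    hits[plat] += 1
            except Exception:
                pass  # treated as a miss
    for plat in sorted(totals):
        print(f"{plat}: {hits[plat] / totals[plat]:.3f} ({hits[plat]}/{totals[plat]})")
```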
