From b16087ad73adc659fe2bcf5d87269061c10cc9e5 Mon Sep 17 00:00:00 2001 From: Tiange Luo Date: Wed, 6 Dec 2023 21:37:15 -0500 Subject: [PATCH] integrate_Cap3D_captions --- objaverse/__init__.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/objaverse/__init__.py b/objaverse/__init__.py index 6b79acf..0e334ac 100644 --- a/objaverse/__init__.py +++ b/objaverse/__init__.py @@ -47,9 +47,25 @@ def load_annotations(uids: Optional[List[str]] = None) -> Dict[str, Any]: urllib.request.urlretrieve(hf_url, local_path) with gzip.open(local_path, "rb") as f: data = json.load(f) + local_cap3d_path = os.path.join(metadata_path, "cap3d_captions.json.gz") + if not os.path.exists(local_cap3d_path): + hf_url = "https://huggingface.co/datasets/tiange/Cap3D/resolve/main/Objaverse_files/cap3d_captions.json.gz" + # wget the caption file and put it in local_path + os.makedirs(os.path.dirname(local_cap3d_path), exist_ok=True) + urllib.request.urlretrieve(hf_url, local_cap3d_path) + with gzip.open(local_cap3d_path, "rt", encoding='UTF-8') as f: + captions = json.load(f) if uids is not None: - data = {uid: data[uid] for uid in uids if uid in data} - out.update(data) + cur_data = {} + for uid in uids: + if uid in data: + tmp_data = data[uid] + if uid in captions.keys(): + tmp_data["cap3d"] ={'caption': captions[uid], 'pointcloud_download_url':'https://huggingface.co/datasets/tiange/Cap3D/tree/main/PointCloud_zips', 'paper':'Scalable 3D Captioning with Pretrained Models'} + else: + tmp_data["cap3d"] = '' + cur_data[uid] = tmp_data + out.update(cur_data) if uids is not None and len(out) == len(uids): break return out