Skip to content

Commit

Permalink
restructuring
Browse files Browse the repository at this point in the history
  • Loading branch information
lillythomas committed Apr 26, 2024
1 parent d9e8eb2 commit 7bb7909
Showing 1 changed file with 18 additions and 24 deletions.
42 changes: 18 additions & 24 deletions scripts/worldcover/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@
"https://huggingface.co/made-with-clay/Clay/resolve/main/"
"Clay_v0.1_epoch-24_val-loss-0.46.ckpt"
)
# CKPT_PATH = "https://huggingface.co/made-with-clay/Clay/resolve/main/Clay_v0.1_epoch-24_val-loss-0.46.ckpt"
VERSION = "003"
VERSION = "005"
BUCKET = "clay-worldcover-embeddings"
URL = "https://esa-worldcover-s2.s3.amazonaws.com/rgbnir/{year}/N{yidx}/ESA_WorldCover_10m_{year}_v{version}_N{yidx}W{xidx}_S2RGBNIR.tif"
WC_VERSION_LOOKUP = {
Expand Down Expand Up @@ -142,18 +141,17 @@ def tiles_and_windows(input: Window):

return result


def download_image(url):
# Download the image from the URL
# Download an image from a URL
response = requests.get(url)
# Check if the request was successful
if response.status_code == 200:
return response.content # Return the image content
else:
raise Exception("Failed to download the image")

def patches_and_windows_from_url(url, chunk_size=(PATCH_SIZE, PATCH_SIZE)):
# Download the image from the URL
def patch_bounds_from_url(url, chunk_size=(PATCH_SIZE, PATCH_SIZE)):
# Download an image from a URL
image_data = download_image(url)

# Open the image using rasterio from memory
Expand Down Expand Up @@ -198,7 +196,6 @@ def patches_and_windows_from_url(url, chunk_size=(PATCH_SIZE, PATCH_SIZE)):

return chunk_bounds, img_crs


def make_batch(result):
pixels = []
for url, win in result:
Expand Down Expand Up @@ -282,7 +279,6 @@ def get_pixels(result):
# Set the model to evaluation mode
rgb_model.eval()


outdir_embeddings = Path("data/embeddings")
outdir_embeddings.mkdir(exist_ok=True, parents=True)

Expand Down Expand Up @@ -327,15 +323,16 @@ def get_pixels(result):


print(len(embeddings), len(results))
#embeddings = numpy.vstack(embeddings)
embeddings_ = embeddings[0]
embeddings_ = numpy.vstack(embeddings)
#embeddings_ = embeddings[0]
print("Embeddings shape: ", embeddings_.shape)

# remove date and lat/lon
embeddings_ = embeddings_[:, :-2, :].mean(axis=0)

embeddings_ = embeddings_[:, :-2, :]

print(f"Embeddings have shape {embeddings_.shape}") #.mean(axis=1)
print(f"Embeddings have shape {embeddings_.shape}")

# remove date and lat/lon and reshape to disaggregated patches
# reshape to disaggregated patches
embeddings_patch = embeddings_.reshape([2, 16, 16, 768])

# average over the band groups
Expand All @@ -347,7 +344,7 @@ def get_pixels(result):
if result is not None:
print("result: ", result[0][0])
pix = get_pixels(result)
chunk_bounds, epsg = patches_and_windows_from_url(result[0][0])
chunk_bounds, epsg = patch_bounds_from_url(result[0][0])
#print("chunk_bounds: ", chunk_bounds)
print("chunk bounds length:", len(chunk_bounds))

Expand All @@ -365,16 +362,14 @@ def get_pixels(result):
item_[0][1]["lon_end"],
item_[0][1]["lat_end"],
]
#source_url = batch["source_url"]
date = batch["date"]
date_as_timestamp = pd.to_datetime(date, format="%Y-%m-%d")

# Convert the Pandas Timestamp to the desired data type
#date_as_date32 = date_as_timestamp.astype('datetime64[D]')

#print(batch["date"])
data = {
"date": date_as_timestamp,
#"source_url": batch["source_url"][0],
#"date": pd.to_datetime(arg=date, format="%Y-%m-%d").astype(
# dtype="date32[day][pyarrow]"
#),
#"date": pd.to_datetime(date, format="%Y-%m-%d", dtype="date32[day][pyarrow]"),
"date": pd.to_datetime(batch["date"], format="%Y-%m-%d"),
"embeddings": [numpy.ascontiguousarray(embeddings_output_patch)],
}

Expand All @@ -390,7 +385,6 @@ def get_pixels(result):

# Reproject to WGS84 (lon/lat coordinates)
gdf = gdf.to_crs(epsg=4326)


with tempfile.TemporaryDirectory() as tmp:
# tmp = "/home/tam/Desktop/wcctmp"
Expand Down

0 comments on commit 7bb7909

Please sign in to comment.