WIP implementation for patch level world cover embeddings #231

Open
wants to merge 11 commits into base: main
12 changes: 6 additions & 6 deletions scripts/worldcover/embeddings_db.py
@@ -6,12 +6,12 @@
from skimage import io

# Set working directory
wd = "/home/usr/Desktop/"
wd = "./"

# To download the existing embeddings run aws s3 sync
# aws s3 sync s3://clay-worldcover-embeddings /my/dir/clay-worldcover-embeddings

vector_dir = Path(wd + "clay-worldcover-embeddings/v002/2021/")
vector_dir = Path(wd + "clay-worldcover-embeddings/2020/")

# Create new DB structure or open existing
db = lancedb.connect(wd + "worldcoverembeddings_db")
@@ -24,17 +24,17 @@

for _, row in tile_df.iterrows():
    data.append(
-        {"vector": row["embeddings"], "year": 2021, "bbox": row.geometry.bounds}
+        {"vector": row["embeddings"], "year": 2020, "bbox": row.geometry.bounds}
    )

# Show table names
db.table_names()

# Drop existing table if exists
db.drop_table("worldcover-2021-v001")
# db.drop_table("worldcover-2020-v001")

# Create embeddings table and insert the vector data
tbl = db.create_table("worldcover-2021-v001", data=data, mode="overwrite")
tbl = db.create_table("worldcover-2020-v001", data=data, mode="overwrite")


# Visualize some image chips
@@ -53,6 +53,6 @@ def plot(df, cols=10):


# Select a vector by index, search for the 5 most similar embeddings, and plot
-v = tbl.to_pandas()["vector"].values[10540]
+v = tbl.to_pandas()["vector"].values[5]
result = tbl.search(query=v).limit(5).to_pandas()
plot(result, 5)
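
For reference, a minimal sketch of narrowing the same search with the stored year column; the where clause is standard LanceDB query-builder usage, though this call is not part of the diff:

# Pre-filter rows on metadata before the vector search ranks neighbours
result = tbl.search(query=v).where("year = 2020").limit(5).to_pandas()
plot(result, 5)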
232 changes: 194 additions & 38 deletions scripts/worldcover/run.py
@@ -1,21 +1,25 @@
#!/usr/bin/env python3

-# import sys
-# sys.path.append("/home/tam/Documents/repos/model")
+import sys
+
+sys.path.append("../../")

import os
import tempfile
from math import floor
from pathlib import Path

import boto3
import einops
import geopandas as gpd
import numpy
import pyarrow as pa
import pandas as pd
import rasterio
import requests
import shapely
import torch
import xarray as xr
from rasterio.windows import Window
from shapely import box
from torchvision.transforms import v2

from src.datamodule import ClayDataset
@@ -24,6 +28,7 @@
YEAR = int(os.environ.get("YEAR", 2020))
DATE = f"{YEAR}-06-01"
TILE_SIZE = 12000
+PATCH_SIZE = 32
CHIP_SIZE = 512
E_W_INDEX_START = 67
E_W_INDEX_END = 125
@@ -32,13 +37,16 @@
YORIGIN = 50.0
XORIGIN = -125.0
PXSIZE = 8.333333333333333e-05
+SUCCESS_CODE = 200

RASTER_X_SIZE = (E_W_INDEX_END - E_W_INDEX_START) * TILE_SIZE
RASTER_Y_SIZE = (N_S_INDEX_END - N_S_INDEX_START) * TILE_SIZE
NODATA = 0
CKPT_PATH = "s3://clay-model-ckpt/v0/mae_epoch-24_val-loss-0.46.ckpt"
# CKPT_PATH = "https://huggingface.co/made-with-clay/Clay/resolve/main/Clay_v0.1_epoch-24_val-loss-0.46.ckpt"
VERSION = "002"
CKPT_PATH = (
"https://huggingface.co/made-with-clay/Clay/resolve/main/"
"Clay_v0.1_epoch-24_val-loss-0.46.ckpt"
)
VERSION = "005"
BUCKET = "clay-worldcover-embeddings"
URL = "https://esa-worldcover-s2.s3.amazonaws.com/rgbnir/{year}/N{yidx}/ESA_WorldCover_10m_{year}_v{version}_N{yidx}W{xidx}_S2RGBNIR.tif"
WC_VERSION_LOOKUP = {
@@ -134,6 +142,63 @@ def tiles_and_windows(input: Window):
    return result


def download_image(url):
    # Download an image from a URL
    response = requests.get(url)
    # Check if the request was successful
    if response.status_code == SUCCESS_CODE:
        return response.content  # Return the image content
    else:
        raise Exception("Failed to download the image")
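
A side note on the status check above: requests can enforce this itself. A minimal alternative sketch, not part of this diff (the helper name and timeout value are illustrative):

def download_image_strict(url):
    # raise_for_status() raises requests.HTTPError on any non-2xx code,
    # replacing the manual SUCCESS_CODE comparison and bare Exception
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    return response.content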


def patch_bounds_from_url(url, chunk_size=(PATCH_SIZE, PATCH_SIZE)):
    # Download an image from a URL
    image_data = download_image(url)

    # Open the image using rasterio from memory
    with rasterio.io.MemoryFile(image_data) as memfile:
        with memfile.open() as src:
            # Read the image data and metadata
            img_data = src.read()
            img_meta = src.profile
            img_crs = src.crs

    # Convert raster data and metadata into an xarray DataArray
    img_da = xr.DataArray(img_data, dims=("band", "y", "x"), attrs=img_meta)

    # Tile the data
    ds_chunked = img_da.chunk({"y": chunk_size[0], "x": chunk_size[1]})

    # Get the geospatial information from the original dataset
    transform = img_meta["transform"]

    # Iterate over the chunks and compute the geospatial bounds for each chunk
    chunk_bounds = {}

    for x in range(ds_chunked.sizes["x"] // chunk_size[1]):
        for y in range(ds_chunked.sizes["y"] // chunk_size[0]):
            # Compute chunk pixel coordinates
            x_start = x * chunk_size[1]
            y_start = y * chunk_size[0]
            x_end = min(x_start + chunk_size[1], ds_chunked.sizes["x"])
            y_end = min(y_start + chunk_size[0], ds_chunked.sizes["y"])

            # Compute chunk geospatial bounds
            lon_start, lat_start = transform * (x_start, y_start)
            lon_end, lat_end = transform * (x_end, y_end)

            # Store chunk bounds
            chunk_bounds[(x, y)] = {
                "lon_start": lon_start,
                "lat_start": lat_start,
                "lon_end": lon_end,
                "lat_end": lat_end,
            }

    return chunk_bounds, img_crs
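
To make the bounds arithmetic concrete: a rasterio affine transform maps pixel (col, row) to CRS coordinates by multiplication, which is what the loop above does once per patch. A minimal sketch built from this script's own constants (the transform is constructed here for illustration rather than read from a tile):

from rasterio.transform import from_origin

# Top-left corner at (XORIGIN, YORIGIN), square pixels of PXSIZE degrees
transform = from_origin(XORIGIN, YORIGIN, PXSIZE, PXSIZE)

# Patch (0, 0) spans pixel columns and rows [0, PATCH_SIZE)
lon_start, lat_start = transform * (0, 0)
lon_end, lat_end = transform * (PATCH_SIZE, PATCH_SIZE)
# The patch is PATCH_SIZE * PXSIZE degrees wide; lat_end < lat_start
# because row indices increase southwards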


def make_batch(result):
    pixels = []
    for url, win in result:
@@ -168,28 +233,67 @@ def make_batch(result):
"timestep": torch.as_tensor(data=[ds.normalize_timestamp(f"{YEAR}-06-01")]).to(
rgb_model.device
),
"date": f"{YEAR}-06-01",
}


def get_pixels(result):
    pixels = []
    for url, win in result:
        with rasterio.open(url) as src:
            data = src.read(window=win)
            if NODATA in data:
                return
            pixels.append(data)
            # transform = src.window_transform(win)

    if len(pixels) == 1:
        pixels = pixels[0]
    elif len(pixels) == 2:  # noqa: PLR2004
        if pixels[0].shape[2] == CHIP_SIZE:
            pixels = einops.pack(pixels, "b * w")[0]
        else:
            pixels = einops.pack(pixels, "b h *")[0]
    else:
        px1 = einops.pack(pixels[:2], "b w *")[0]
        px2 = einops.pack(pixels[2:], "b w *")[0]
        pixels = einops.pack((px1, px2), "b * w")[0]

    assert pixels.shape == (4, CHIP_SIZE, CHIP_SIZE)

    return pixels
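
The einops.pack calls above mosaic up to four neighbouring tile windows into a single chip; pack concatenates its inputs along the axis marked *. A small self-contained sketch with illustrative shapes:

import einops
import numpy

left = numpy.zeros((4, 512, 256))
right = numpy.ones((4, 512, 256))
# "b h *" keeps the band and height axes and concatenates along width
chip = einops.pack((left, right), "b h *")[0]
assert chip.shape == (4, 512, 512)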


index = int(os.environ.get("AWS_BATCH_JOB_ARRAY_INDEX", 2))

# Setup model components
tfm = v2.Compose([v2.Normalize(mean=MEAN, std=STD)])
ds = ClayDataset(chips_path=[], transform=tfm)

# Load model
rgb_model = CLAYModule.load_from_checkpoint(
    CKPT_PATH,
    mask_ratio=0.0,
-    band_groups={"rgb": (0, 1, 2), "nir": (3,)},
+    band_groups={"rgb": (2, 1, 0), "nir": (3,)},
    bands=4,
    strict=False,  # ignore the extra parameters in the checkpoint
    embeddings_level="group",
)
# Set the model to evaluation mode
rgb_model.eval()

outdir_embeddings = Path("data/embeddings")
outdir_embeddings.mkdir(exist_ok=True, parents=True)

xoff = index * CHIP_SIZE
yoff = 0
embeddings = []
all_bounds = []
results = []
while yoff < RASTER_Y_SIZE:
    result = tiles_and_windows(Window(xoff, yoff, CHIP_SIZE, CHIP_SIZE))
    if result is not None:
        results.append(result)

    if result is None:
        yoff += CHIP_SIZE
Expand Down Expand Up @@ -219,33 +323,85 @@ def make_batch(result):

yoff += CHIP_SIZE

-embeddings = numpy.vstack(embeddings)
-
-embeddings_mean = embeddings[:, :-2, :].mean(axis=1)
-
-print(f"Average embeddings have shape {embeddings_mean.shape}")
-
-gdf = gpd.GeoDataFrame(
-    data={
-        "embeddings": pa.FixedShapeTensorArray.from_numpy_ndarray(
-            numpy.ascontiguousarray(embeddings_mean)
-        ),
-    },
-    geometry=[box(*dat) for dat in all_bounds],  # This assumes same order
-    crs="EPSG:4326",
-)
-
-with tempfile.TemporaryDirectory() as tmp:
-    # tmp = "/home/tam/Desktop/wcctmp"
-
-    outpath = f"{tmp}/worldcover_embeddings_{YEAR}_{index}_v{VERSION}.gpq"
-    print(f"Uploading embeddings to {outpath}")
-
-    gdf.to_parquet(path=outpath, compression="ZSTD", schema_version="1.0.0")
-
-    s3_client = boto3.client("s3")
-    s3_client.upload_file(
-        outpath,
-        BUCKET,
-        f"v{VERSION}/{YEAR}/{os.path.basename(outpath)}",
-    )
+print(len(embeddings), len(results))
+embeddings_ = numpy.vstack(embeddings)
+# embeddings_ = embeddings[0]
+print("Embeddings shape: ", embeddings_.shape)
+
+# remove date and lat/lon
+embeddings_ = embeddings_[:, :-2, :].mean(axis=0)
+
+print(f"Embeddings have shape {embeddings_.shape}")
+
+# reshape to disaggregated patches
+embeddings_patch = embeddings_.reshape([2, 16, 16, 768])
+
+# average over the band groups
+embeddings_mean = embeddings_patch.mean(axis=0)
+
+print(f"Average patch embeddings have shape {embeddings_mean.shape}")
+
+if result is not None:
+    print("result: ", result[0][0])
+    pix = get_pixels(result)
+    chunk_bounds, epsg = patch_bounds_from_url(result[0][0])
+    # print("chunk_bounds: ", chunk_bounds)
+    print("chunk bounds length:", len(chunk_bounds))
+
+    # Iterate through each patch
+    for i in range(embeddings_mean.shape[0]):
+        for j in range(embeddings_mean.shape[1]):
+            embeddings_output_patch = embeddings_mean[i, j]
+
+            item_ = [
+                element
+                for element in list(chunk_bounds.items())
+                if element[0] == (i, j)
+            ]
+            box_ = [
+                item_[0][1]["lon_start"],
+                item_[0][1]["lat_start"],
+                item_[0][1]["lon_end"],
+                item_[0][1]["lat_end"],
+            ]
+
+            data = {
+                "date": pd.to_datetime(batch["date"], format="%Y-%m-%d"),
+                "embeddings": [numpy.ascontiguousarray(embeddings_output_patch)],
+            }
+
+            # Define the bounding box as a Polygon (xmin, ymin, xmax, ymax)
+            # The box_ list is encoded as
+            # [bottom left x, bottom left y, top right x, top right y]
+            box_emb = shapely.geometry.box(box_[0], box_[1], box_[2], box_[3])
+
+            print(str(epsg)[-4:])
+
+            # Create the GeoDataFrame
+            gdf = gpd.GeoDataFrame(
+                data, geometry=[box_emb], crs=f"EPSG:{str(epsg)[-4:]}"
+            )
+
+            # Reproject to WGS84 (lon/lat coordinates)
+            gdf = gdf.to_crs(epsg=4326)
+
+            with tempfile.TemporaryDirectory() as tmp:
+                # tmp = "/home/tam/Desktop/wcctmp"
+
+                outpath = (
+                    f"{tmp}/worldcover_patch_embeddings_{YEAR}_{index}_{i}_{j}_"
+                    f"v{VERSION}.gpq"
+                )
+                print(f"Uploading embeddings to {outpath}")
+                # print(gdf)
+
+                gdf.to_parquet(
+                    path=outpath, compression="ZSTD", schema_version="1.0.0"
+                )
+
+                s3_client = boto3.client("s3")
+                s3_client.upload_file(
+                    outpath,
+                    BUCKET,
+                    f"v{VERSION}/{YEAR}/{os.path.basename(outpath)}",
+                )
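
A closing note on the patch arithmetic: with CHIP_SIZE = 512 and PATCH_SIZE = 32, each chip yields a 16 x 16 grid of patch embeddings per band group, which is where the [2, 16, 16, 768] reshape target comes from. A sketch of the derivation using the script's constants:

n_patches = CHIP_SIZE // PATCH_SIZE  # 512 // 32 == 16 patches per side
n_band_groups = 2                    # the "rgb" and "nir" groups
# After dropping the trailing date and lat/lon embeddings, each chip
# carries n_band_groups * n_patches * n_patches == 512 patch tokens of
# width 768, hence embeddings_.reshape([2, 16, 16, 768])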