diff --git a/docs/tutorials/Earthquake-preds-Clay.ipynb b/docs/tutorials/Earthquake-preds-Clay.ipynb
new file mode 100644
index 00000000..74caaedb
--- /dev/null
+++ b/docs/tutorials/Earthquake-preds-Clay.ipynb
@@ -0,0 +1,793 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "0e2c4b4b-063f-4b9f-8a6e-df8d48ba4b3a",
+ "metadata": {},
+ "source": [
+ "**Problem Statement**"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "26154498-17c4-4aff-94a4-f69fdfabbdde",
+ "metadata": {},
+ "source": [
+ "Earthquake monitoring and damage estimation are essential to ensure that aid reaches affected areas in a timely manner. However, the coverage of earthquake monitoring networks is limited and inadequate for remote areas. One possible solution is to use satellite imagery to detect the occurrence of an earthquake and to get a sense of the extent of the damage.\n",
+ "\n",
+ "This notebook presents a workflow that uses the embeddings generated by the Clay foundation model on Sentinel-1 data to predict whether an earthquake has occurred. The code below walks through this binary classification problem (detecting an earthquake) on Sentinel-1 imagery, which is free and available through the Copernicus Sentinel program.\n",
+ "\n",
+ "For more information on Sentinel-1 bands, please see https://developers.google.com/earth-engine/datasets/catalog/COPERNICUS_S1_GRD"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5575d9af-8917-49bb-844b-64356060d72e",
+ "metadata": {},
+ "source": [
+ "**About the dataset**"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "77210df1-d1e8-4142-833c-0639052c0e1c",
+ "metadata": {},
+ "source": [
+ "*QuakeSet: A Dataset and Low-Resource Models to Monitor Earthquakes through Sentinel-1.*\n",
+ "\n",
+ "Citation:\n",
+ "@misc{cambrin2024quakesetdatasetlowresourcemodels,\n",
+ " title={QuakeSet: A Dataset and Low-Resource Models to Monitor Earthquakes through Sentinel-1}, \n",
+ " author={Daniele Rege Cambrin and Paolo Garza},\n",
+ " year={2024},\n",
+ " eprint={2403.18116},\n",
+ " archivePrefix={arXiv},\n",
+ " primaryClass={cs.CV},\n",
+ " url={https://arxiv.org/abs/2403.18116}, \n",
+ "}\n",
+ "\n",
+ "Hugging Face link: https://huggingface.co/datasets/DarthReca/quakeset\n",
+ "\n",
+ "GitHub link: https://github.com/DarthReca/quakeset/\n",
+ "\n",
+ "torchgeo link: https://torchgeo.readthedocs.io/en/latest/_modules/torchgeo/datasets/quakeset.html#QuakeSet\n",
+ "\n",
+ "**Dataset Structure**\n",
+ "\n",
+ "The dataset is divided into three folds with equal distribution of magnitudes and balanced in positive and negative examples.\n",
+ "\n",
+ "Dataset features:\n",
+ "\n",
+ "- Sentinel-1 SAR imagery\n",
+ "\n",
+ "- before/pre/post imagery of areas affected by earthquakes\n",
+ "\n",
+ "- 2 SAR bands (VV/VH)\n",
+ "\n",
+ "- 3,327 pairs of pre and post images with 5 m per pixel resolution (512x512 px)\n",
+ "\n",
+ "- 2 classification labels (unaffected / affected by earthquake) (**used in this notebook for binary classification**)\n",
+ "\n",
+ "- pre/post image pairs represent earthquake affected areas\n",
+ "\n",
+ "- before/pre image pairs represent hard negative unaffected areas\n",
+ "\n",
+ "- earthquake magnitudes for each sample\n",
+ "\n",
+ "Dataset format:\n",
+ "\n",
+ "- single HDF5 file containing images, magnitudes, hypocenters, and splits\n",
+ "\n",
+ "Dataset classes:\n",
+ "\n",
+ "- unaffected area\n",
+ "\n",
+ "- earthquake affected area\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bf424db3-765c-4f13-953e-4bf180061696",
+ "metadata": {},
+ "source": [
+ "**Approaches to the problem**"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "84c1ddaf-3149-4461-a0bc-91a4b68d99a2",
+ "metadata": {},
+ "source": [
+ "Each sample consists of 2 pre-event images (VV and VH bands) and 2 post-event images (VV and VH bands), each 512 x 512 pixels. The ground truth is 1 or 0, indicating whether an earthquake has occurred. This binary classification problem can be solved using multiple approaches, such as:\n",
+ "\n",
+ "1. Using traditional machine learning \n",
+ "The images can be flattened so that each data point is a vector of length 512 x 512 x 4 with a label of 1 or 0. The data can then be fed into a traditional ML model such as an SVM or RF classifier, either directly or after the inputs have undergone dimensionality reduction through PCA, t-SNE, etc.\n",
+ "\n",
+ "2. Using deep learning \n",
+ "Deep learning models such as CNNs can be trained on the imagery, treating the input as having 4 channels with labels of 1 or 0. A binary cross-entropy loss can be used to optimize the CNN.\n",
+ "\n",
+ "3. Using embeddings of pre-trained models \n",
+ "Foundation models have revolutionized the deep learning space: these models are pre-trained on large amounts of data and learn a lot of domain knowledge from the data itself. Recently, the application of these models has been especially visible in the NLP space (think LLMs and ChatGPT). Clay is a foundation model for earth observation, and the embeddings it generates on satellite data (from different platforms) can be used for multiple downstream tasks such as classification and regression, since the embeddings contain rich information about the imagery in a latent space.\n",
+ "\n",
+ "Our approach in this notebook is to use the Clay pre-trained model to generate embeddings on the Sentinel-1 imagery (of the QuakeSet dataset) and use those embeddings as inputs to a random forest classifier for binary classification. Our belief is that since these embeddings already contain a lot of information about the imagery, they should serve as useful features for classification.\n",
+ "\n",
+ "This approach is also computationally cheaper and faster, since high-end GPUs are not needed to train classifier models on embeddings and these models can be trained in a few minutes. By contrast, training large deep learning models requires extensive hardware and long training times spanning days and sometimes months."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "16dd744e-0f96-47d8-9766-49771e01e71b",
+ "metadata": {},
+ "source": [
+ "**Using Clay embeddings to detect earthquakes - Code walkthrough**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "5bc8c2d1-f2e1-49c9-911b-a4b1fb5efed5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# git clone https://github.com/Clay-foundation/model  # Run this once to fetch the Clay model code"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "52d10e8f-ebfb-4b17-8d84-c57af1188f53",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# pip install torchgeo[all]\n",
+ "# pip install git+https://github.com/microsoft/torchgeo.git #Use this to be able to access the QuakeSet dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "89f8bc4b-73ca-47e3-aa07-f444a871ca56",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# importing the required modules\n",
+ "\n",
+ "import torchgeo\n",
+ "import torchgeo.datasets as datasets\n",
+ "import sys\n",
+ "import torch\n",
+ "import numpy as np\n",
+ "from sklearn.ensemble import RandomForestClassifier\n",
+ "import xgboost as xgb\n",
+ "from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "import matplotlib.pyplot as plt\n",
+ "from tqdm.auto import tqdm\n",
+ "from sklearn.metrics import f1_score\n",
+ "\n",
+ "sys.path.append(\"model/\")\n",
+ "\n",
+ "from src.model import ClayMAEModule"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "d87a1df1-9e67-40af-9b76-147e7c3b6a38",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'0.6.0.dev0'"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "torchgeo.__version__ # check torchgeo version"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f7f5ae92-1978-40a1-adcd-56a087f1f67a",
+ "metadata": {},
+ "source": [
+ "**Downloading the QuakeSet dataset**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "57c7327c-e454-4ca6-bb50-fe67272ca3fd",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Downloading https://cdn-lfs-us-1.huggingface.co/repos/67/91/67919aae2184524f91dca1003493d36f4fc4799f55277f03c820fc4c90423eaa/11527e6a21c425b787d0952443e434f53d90a22ee16b75d20e17d54c2b091a78?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27earthquakes.h5%3B+filename%3D%22earthquakes.h5%22%3B&Expires=1723272724&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMzI3MjcyNH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zLzY3LzkxLzY3OTE5YWFlMjE4NDUyNGY5MWRjYTEwMDM0OTNkMzZmNGZjNDc5OWY1NTI3N2YwM2M4MjBmYzRjOTA0MjNlYWEvMTE1MjdlNmEyMWM0MjViNzg3ZDA5NTI0NDNlNDM0ZjUzZDkwYTIyZWUxNmI3NWQyMGUxN2Q1NGMyYjA5MWE3OD9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPSoifV19&Signature=AAAWpgowEaqyI7lPQs6QTjMYZq03OgAjaKuaW5fLO5gGfj5rbLd566zV5%7EiImjr0ejHSQi1-VIaimbTHd1AG%7E7SBOBqNgLmsbWOYjjTndDN5kWTEBlYggCiBa7msD57x-3ZENB%7ETdMtRx%7EmVdURvgTiWVxOfT3uisCo78EtEL8AycpuU9or6zX0ay17VFc5uPb7wWmU0kpMsG-yF62jo13DqGMScwEFy-OgntW9oPyfOd0NbCo18y7rjtSp1uSbGachkkPzzQcp-0vqX7Z31UjvFHJoom9R38FzcnbJOW0xk%7Ez2YQpcP6ZOwYXSyYZj1CM-7eTI3UPyAia4uMvPyHw__&Key-Pair-Id=K24J24Z295AEI9 to data/earthquakes.h5\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 9666460547/9666460547 [03:50<00:00, 41959334.33it/s]\n"
+ ]
+ }
+ ],
+ "source": [
+ "# We download the QuakeSet dataset available in torchgeo.\n",
+ "\n",
+ "train_ds = datasets.QuakeSet(\n",
+ " split=\"train\", download=True\n",
+ ") # download=True fetches the HDF5 file on first use\n",
+ "val_ds = datasets.QuakeSet(split=\"val\", download=True)\n",
+ "test_ds = datasets.QuakeSet(split=\"test\", download=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "c2ebdc0c-d4d8-4ba2-b97d-babb0d0d4d70",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([4, 512, 512])"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Checking the Sentinel-1 imagery size.\n",
+ "\n",
+ "# Each sample consists of pre and post event data with two channels - vv and vh of Sentinel-1\n",
+ "\n",
+ "train_ds[0][\"image\"].shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "df682dd4-7eee-48a7-a42b-3c45cc2a7460",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Downloading: \"https://clay-model-ckpt.s3.amazonaws.com/v0.5.7/mae_v0.5.7_epoch-13_val-loss-0.3098.ckpt\" to /teamspace/studios/this_studio/.cache/torch/hub/checkpoints/mae_v0.5.7_epoch-13_val-loss-0.3098.ckpt\n",
+ "100%|██████████| 1.61G/1.61G [01:49<00:00, 15.7MB/s]\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "78f781a4d10b4a05bffc926451cd3d4b",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "model.safetensors: 0%| | 0.00/343M [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/transformer.py:286: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)\n",
+ " warnings.warn(f\"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}\")\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Download Clay-1 model\n",
+ "\n",
+ "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
+ "ckpt = \"https://clay-model-ckpt.s3.amazonaws.com/v0.5.7/mae_v0.5.7_epoch-13_val-loss-0.3098.ckpt\"\n",
+ "torch.set_default_device(device)\n",
+ "\n",
+ "model = ClayMAEModule.load_from_checkpoint(\n",
+ " ckpt, metadata_path=\"model/configs/metadata.yaml\", shuffle=False, mask_ratio=0\n",
+ ")\n",
+ "model.eval()\n",
+ "\n",
+ "model = model.to(device)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "b115e660-fb27-4340-a8e5-09d29c330d1b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Dataset class to use for loading the data\n",
+ "\n",
+ "\n",
+ "class EarthQuakeDataset:\n",
+ "    def __init__(self, ds):\n",
+ "        self.ds = ds\n",
+ "\n",
+ "    def __len__(self):\n",
+ "        return len(self.ds)\n",
+ "\n",
+ "    def __getitem__(self, idx):\n",
+ "        item = self.ds[idx]  # Index once; each self.ds[idx] call reloads the sample\n",
+ "        image = item[\"image\"]  # 4 x 512 x 512\n",
+ "\n",
+ "        # First two channels are the pre-event Sentinel-1 image (vv & vh bands)\n",
+ "        pre_image = image[:2, :, :]\n",
+ "        # Last two channels are the post-event Sentinel-1 image (vv & vh bands)\n",
+ "        post_image = image[2:, :, :]\n",
+ "\n",
+ "        sample = {\n",
+ "            \"pixels1\": pre_image,  # 2 x 512 x 512\n",
+ "            \"pixels2\": post_image,  # 2 x 512 x 512\n",
+ "            \"time\": torch.zeros(4),  # Placeholder for time information\n",
+ "            \"latlon\": torch.zeros(4),  # Placeholder for latlon information\n",
+ "            \"label\": item[\"label\"],\n",
+ "        }\n",
+ "\n",
+ "        return sample\n",
+ "\n",
+ "\n",
+ "# Construct training/validation/test dataset objects\n",
+ "train_dataset = EarthQuakeDataset(train_ds)\n",
+ "validation_dataset = EarthQuakeDataset(val_ds)\n",
+ "test_dataset = EarthQuakeDataset(test_ds)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "bbc4fd0a-0483-4dfd-a82e-a0da1a1530de",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Dataloaders from dataset\n",
+ "BS = 8\n",
+ "\n",
+ "train_dl = torch.utils.data.DataLoader(\n",
+ " train_dataset, batch_size=BS, shuffle=True, generator=torch.Generator(device=device)\n",
+ ")\n",
+ "val_dl = torch.utils.data.DataLoader(\n",
+ " validation_dataset,\n",
+ " batch_size=BS,\n",
+ " shuffle=False,\n",
+ " generator=torch.Generator(device=device),\n",
+ ")\n",
+ "test_dl = torch.utils.data.DataLoader(\n",
+ " test_dataset, batch_size=BS, shuffle=False, generator=torch.Generator(device=device)\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "6ac0b263-f334-41ee-88fd-cab4195cb767",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Generate embeddings for the data\n",
+ "\n",
+ "\n",
+ "def generate_and_save_embeddings(dl, split=\"train\"):\n",
+ "    gsd = torch.tensor(10, device=device)  # Ground sampling distance for Sentinel-1\n",
+ "    waves = torch.tensor([3.5, 4.0], device=device)  # wavelengths for Sentinel-1\n",
+ "\n",
+ "    embeddings1 = []\n",
+ "    embeddings2 = []\n",
+ "    target = []\n",
+ "\n",
+ "    for bid, batch in enumerate(tqdm(dl)):\n",
+ "        datacube1 = {\n",
+ "            \"pixels\": batch[\"pixels1\"].to(device),\n",
+ "            \"time\": batch[\"time\"].to(device),\n",
+ "            \"latlon\": batch[\"latlon\"].to(device),\n",
+ "            \"gsd\": gsd,\n",
+ "            \"waves\": waves,\n",
+ "        }\n",
+ "\n",
+ "        datacube2 = {\n",
+ "            \"pixels\": batch[\"pixels2\"].to(device),\n",
+ "            \"time\": batch[\"time\"].to(device),\n",
+ "            \"latlon\": batch[\"latlon\"].to(device),\n",
+ "            \"gsd\": gsd,\n",
+ "            \"waves\": waves,\n",
+ "        }\n",
+ "\n",
+ "        with torch.no_grad():\n",
+ "            unmsk_patch1, unmsk_idx1, msk_idx1, msk_matrix1 = model.model.encoder(\n",
+ "                datacube1\n",
+ "            )\n",
+ "            unmsk_patch2, unmsk_idx2, msk_idx2, msk_matrix2 = model.model.encoder(\n",
+ "                datacube2\n",
+ "            )\n",
+ "\n",
+ "        emb1 = unmsk_patch1[:, 0, :].cpu().numpy()\n",
+ "        emb2 = unmsk_patch2[:, 0, :].cpu().numpy()\n",
+ "\n",
+ "        embeddings1.append(emb1)\n",
+ "        embeddings2.append(emb2)\n",
+ "        target.append(batch[\"label\"].cpu().numpy())\n",
+ "\n",
+ "    # Saving embeddings and ground truth (label) data\n",
+ "    np.save(\n",
+ "        f\"{split}_emb.npy\",\n",
+ "        np.concatenate(\n",
+ "            (np.concatenate(embeddings1), np.concatenate(embeddings2)), axis=1\n",
+ "        ),\n",
+ "    )\n",
+ "    np.save(f\"{split}_label.npy\", np.concatenate(target))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "8a12f703-ffde-4967-9e8a-ecfca9ba5e93",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "b03096d21c484c41b35bcbc9d5aedcf6",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/284 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "generate_and_save_embeddings(train_dl, \"train\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ed0bcfdf-fb98-4343-84b9-9d8d5d9813a4",
+ "metadata": {},
+ "source": [
+ "Generating embeddings for the train split took ~48 minutes on a single T4 GPU with a batch size of 8. This step will run faster on more powerful GPUs."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "144d79bd-833e-41b4-9c61-22fe4c0c0f27",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "0329e1101a2047929fd690db9803c03d",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/69 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "generate_and_save_embeddings(val_dl, \"val\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "5d2a0cef-12b3-498b-b696-56ba3e777a48",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "78370eae24fd4011b75bdfc1df90cb2c",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/64 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "generate_and_save_embeddings(test_dl, \"test\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "id": "95b8ffc3-f030-4d83-83eb-c449be524603",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "((2266, 1536), (2266,))"
+ ]
+ },
+ "execution_count": 29,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Load the saved embeddings and ground truth (labels); each row concatenates the pre- and post-event embeddings\n",
+ "\n",
+ "train_fea = np.load(\"train_emb.npy\")\n",
+ "train_label = np.load(\"train_label.npy\")\n",
+ "\n",
+ "train_fea.shape, train_label.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "id": "4545418a-6a19-4f70-8665-f5d9f8e09204",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ "