Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

model with context 2 #31

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion configs/naip-multilabel-contextual.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ dataloader:
metadata_file: /opt/data/california-naip-chips/california-naip-chips-100k.parquet
neighbor_embeddings_folder: /opt/data/california-naip-chips/california-naip-chips-100k-neighbours/npy
neighborhood_radius: 8 # max 8
collapse_osm_cube: True
get_osm_strlabels: True
get_osm_ohearea: True
get_osm_ohecount: True
get_osm_ohelength: True
get_chip_id: True
embeddings_normalization: True
osmvector_normalization: False
# multilabel_threshold_osm_ohecount: 1
Expand All @@ -16,7 +18,7 @@ dataloader:
num_workers: 4

model:
_target_: earthtext.models.multilabel.MultisizeContextualCNN
_target_: earthtext.models.multilabel.MultiscaleContextualCNN
input_dim: 768
output_dim: 140
layers_spec: [512, 256, 128]
Expand Down
314 changes: 179 additions & 135 deletions notebooks/models/04e - multilabel classification w context v1 NAIP.ipynb

Large diffs are not rendered by default.

148 changes: 125 additions & 23 deletions notebooks/naip/00c - neighbours 3D embedding arrays.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -198,15 +198,15 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"id": "4f20d2d4-af19-49e6-86d2-94aa0d378171",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[38;2;255;117;0m 22%\u001b[39m \u001b[38;2;255;117;0m(23089 of 104234)\u001b[39m |### | Elapsed Time: 0:02:59 ETA: 0:10:32"
"\u001b[38;2;0;255;0m100%\u001b[39m \u001b[38;2;0;255;0m(104234 of 104234)\u001b[39m |################| Elapsed Time: 1:04:41 Time: 1:04:413110\n"
]
}
],
Expand Down Expand Up @@ -297,6 +297,7 @@
"source": [
"folder_grid = \"/opt/data/california-naip-chips/california-naip-chips-100k-neighbours/grid\"\n",
"folder_npy = \"/opt/data/california-naip-chips/california-naip-chips-100k-neighbours/npy\"\n",
"folder_osm = \"/opt/data/california-naip-chips/california-naip-chips-100k-neighbours/osm\"\n",
"file_grid = np.random.choice(os.listdir(folder_grid))\n",
"file_npy = np.random.choice(os.listdir(folder_npy))\n",
"df = pd.read_parquet(f\"{folder_grid}/{file_grid}\")\n",
Expand Down Expand Up @@ -336,18 +337,16 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "78e15e82-98c9-4481-b5b7-d2cde0044eed",
"cell_type": "markdown",
"id": "0b6c922b-b3ad-4f1f-89b1-b3c60706fe98",
"metadata": {},
"outputs": [],
"source": [
"# TODO: Validate the OSM cube by checking it the sum over TAGS dimension equals the aggregated OSM"
"Validate the OSM cube by checking it the sum over TAGS dimension equals the aggregated OSM"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 29,
"id": "5aa5601d-26bb-47ba-83f6-4f606aa114fe",
"metadata": {},
"outputs": [],
Expand All @@ -360,17 +359,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"id": "33822830-d4db-4991-b7a5-eb85cd46c83b",
"metadata": {},
"outputs": [],
"source": [
"# df_agg = pd.read_parquet(f'{folder}/osm_aggregate.parquet')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 30,
"id": "3727a4d0-fc79-4438-a173-59f609db2404",
"metadata": {
"scrolled": true
Expand All @@ -380,12 +369,14 @@
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[38;2;0;255;0m100%\u001b[39m \u001b[38;2;0;255;0m(3210 of 3210)\u001b[39m |####################| Elapsed Time: 0:01:39 Time: 0:01:390008\n"
"\u001b[38;2;0;255;0m100%\u001b[39m \u001b[38;2;0;255;0m(1 of 1)\u001b[39m |##########################| Elapsed Time: 0:00:00 Time: 0:00:00\n"
]
}
],
"source": [
"for i in pbar(range(len(df_agg))):\n",
"# CAUTION before you run: takes an hour\n",
"# for i in pbar(range(len(df_agg))):\n",
"for i in pbar(range(1)):\n",
" row = df_agg.iloc[i]\n",
" z = pd.read_parquet(f\"{folder}/{row.name}.parquet\")\n",
" _osm = _df.loc[z.chipid[z.chipid.isin(_df.index)]].sum()\n",
Expand All @@ -394,12 +385,123 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 31,
"id": "6bf25a90-ba76-4f54-9d4e-4134112b97c9",
"metadata": {},
"outputs": [],
"source": [
"df_agg.to_parquet(f'{folder}/osm_aggregate.parquet')"
"## CAUTION: will overwrite the existing .parquet file\n",
"# df_agg.to_parquet(f'{folder}/osm_aggregate.parquet')"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "9eec7ac2-7fc1-4510-94fa-9ab74bfe6533",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ca_m_3411730_sw_11_060_20220501-36-11\n"
]
},
{
"data": {
"text/plain": [
"onehot_count [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 0,...\n",
"onehot_area [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n",
"onehot_length [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n",
"Name: ca_m_3411730_sw_11_060_20220501-36-11, dtype: object"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(row.name)\n",
"df_agg.loc[row.name]"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "d0ddd3b8-6ca4-4680-9ed6-6d925189f325",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(140, 3, 17, 17)"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = np.load(f'{folder}/osm/{row.name}.osm.npy')\n",
"x.shape"
]
},
{
"cell_type": "markdown",
"id": "fdb4d999-8144-4821-8176-1eb173cdc4af",
"metadata": {},
"source": [
"Compare. NOTE: `df_agg` aggregates osm data from ALL neighbors in the circle, beyond the 17x17 grid.\n",
"\n",
"But they look similar enough for the purposes of this sanity check."
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "c257178e-4afb-4f77-b537-13d54f507503",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0 0 0 0 0 0 0 0 0 0 0 0 0 13 0 0 0 0 0 5 0 0 0 0\n",
" 0 8 0 0 0 0 0 0 0 0 56 0 0 0 0 49 7 0 0 28 0 0 4 0\n",
" 0 0 0 0 2 0 0 0 0 0 0 0 0 22 0 0 0 0 0 0 0 0 0 0\n",
" 15 0 0 0 0 12 0 0 0 0 3 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
" 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
" 0 0 0 0 0 0 0 0 0 0 0 0 0 24 0 24 9 0 3 6]\n"
]
}
],
"source": [
"print(df_agg.loc[row.name]['onehot_count'])"
]
},
{
"cell_type": "code",
"execution_count": 59,
"id": "d696c4b2-0edd-417a-8fd9-282f307ca05c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0 0 0 0 0 0 0 0 0 0 0 0 0 8 0 0 0 0 0 3 0 0 0 0\n",
" 0 5 0 0 0 0 0 0 0 0 25 0 0 0 0 23 2 0 0 14 0 0 4 0\n",
" 0 0 0 0 1 0 0 0 0 0 0 0 0 9 0 0 0 0 0 0 0 0 0 0\n",
" 8 0 0 0 0 8 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
" 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
" 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 3 0 0 3]\n"
]
}
],
"source": [
"print(x.sum(axis=(2,3))[:, 0].astype(int))"
]
}
],
Expand Down
Loading