Commit
looking at embeddings
Made an embedding extractor that works better (ReefCLR_embedding_extractor).
Did the PCA plots and saved these.
Ran a random forest.
BenUCL committed Aug 25, 2023
1 parent fb190ca commit a200987
Showing 8 changed files with 442 additions and 80 deletions.
145 changes: 135 additions & 10 deletions code/notebooks/embedding_extractor/embedding_extractor.ipynb
@@ -33,15 +33,16 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"starting_weights = \"/home/ben/reef-audio-representation-learning/code/simclr-pytorch-reefs/logs/exman-train.py/runs/baseline/checkpoint-5100.pth.tar\"\n",
"\n",
"cfg = {'num_classes': 2, 'starting_weights': starting_weights, 'finetune': False,\n",
" 'data_path': '/mnt/ssd-cluster/ben/data/full_dataset/', \n",
" 'json_path': '/home/ben/reef-audio-representation-learning/data/dataset.json'}"
" 'json_path': '/home/ben/reef-audio-representation-learning/data/dataset.json',\n",
" 'test_dataset': 'test_australia'} #######################"
]
},
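Before instantiating anything, it can save a debugging round trip to confirm that the checkpoint path in cfg resolves. A small sketch, not part of the committed cell; the keys inside the checkpoint depend on how simclr-pytorch saved it:

    import os
    import torch

    # Fail fast if the checkpoint path is wrong, then peek at its contents.
    assert os.path.exists(cfg['starting_weights']), cfg['starting_weights']
    ckpt = torch.load(cfg['starting_weights'], map_location='cpu')
    print(list(ckpt)[:5] if isinstance(ckpt, dict) else type(ckpt))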
{
@@ -53,7 +54,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -67,6 +68,19 @@
" self.convnet.fc = nn.Identity()\n",
" self.classifier = nn.Linear(in_features, num_classes)\n",
"\n",
" def forward(self, x):\n",
" '''\n",
" Forward pass. Here, we define how to apply our model. It's basically\n",
" applying our modified ResNet-18 on the input tensor (\"x\") and then\n",
" apply the final classifier layer on the ResNet-18 output to get our\n",
" num_classes prediction.\n",
" '''\n",
" # x.size(): [B x 3 x W x H]\n",
" features = self.convnet(x) # features.size(): [B x 512 x W x H]\n",
" prediction = self.classifier(features) # prediction.size(): [B x num_classes]\n",
"\n",
" return prediction\n",
"\n",
"# Your function to load pretrained weights\n",
"def load_pretrained_weights(cfg, model):\n",
" custom_weights = cfg['starting_weights']\n",
@@ -81,12 +95,12 @@
" pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and k not in ['classifier.weight', 'classifier.bias']}\n",
" log = model.load_state_dict(pretrained_dict, strict=False)\n",
" assert log.missing_keys == ['classifier.weight', 'classifier.bias']\n",
" return model\n"
" return model"
]
},
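The two visible lines of __init__ carry the whole idea: with fc replaced by nn.Identity(), convnet returns the pooled feature vector (the embedding) and classifier maps it to logits. A quick shape check, as a sketch with random weights; the input size and feature_dim depend on the backbone the collapsed __init__ lines actually build:

    import torch

    m = SimClrPytorchResNet50(2)                  # num_classes=2, as in cfg
    dummy = torch.randn(1, 3, 224, 224)           # [B x 3 x W x H], per the comments
    print(m.convnet(dummy).shape)                 # [1, feature_dim]: the embedding
    print(m(dummy).shape)                         # [1, 2]: classifier logits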
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 12,
"metadata": {},
"outputs": [
{
@@ -102,7 +116,6 @@
],
"source": [
"# Initialize your model\n",
"\n",
"model_instance = SimClrPytorchResNet50(cfg['num_classes'])\n",
"\n",
"# Load the pretrained weights\n",
@@ -138,7 +151,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
@@ -147,9 +160,23 @@
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 23,
"metadata": {},
"outputs": [],
"outputs": [
{
"ename": "ValueError",
"evalue": "num_samples should be a positive integer value, but got num_samples=0",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[23], line 11\u001b[0m\n\u001b[1;32m 6\u001b[0m train_percent \u001b[39m=\u001b[39m \u001b[39m'\u001b[39m\u001b[39m0.5\u001b[39m\u001b[39m'\u001b[39m\n\u001b[1;32m 8\u001b[0m dataset \u001b[39m=\u001b[39m CTDataset_train(cfg, split\u001b[39m=\u001b[39msplit, transform\u001b[39m=\u001b[39mtransform, train_percent\u001b[39m=\u001b[39mtrain_percent)\n\u001b[0;32m---> 11\u001b[0m sample_loader \u001b[39m=\u001b[39m DataLoader(dataset, batch_size\u001b[39m=\u001b[39;49m\u001b[39m1\u001b[39;49m, shuffle\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)\n",
"File \u001b[0;32m~/miniconda3/envs/simclr_pytorch_reefs/lib/python3.8/site-packages/torch/utils/data/dataloader.py:351\u001b[0m, in \u001b[0;36mDataLoader.__init__\u001b[0;34m(self, dataset, batch_size, shuffle, sampler, batch_sampler, num_workers, collate_fn, pin_memory, drop_last, timeout, worker_init_fn, multiprocessing_context, generator, prefetch_factor, persistent_workers, pin_memory_device)\u001b[0m\n\u001b[1;32m 349\u001b[0m \u001b[39melse\u001b[39;00m: \u001b[39m# map-style\u001b[39;00m\n\u001b[1;32m 350\u001b[0m \u001b[39mif\u001b[39;00m shuffle:\n\u001b[0;32m--> 351\u001b[0m sampler \u001b[39m=\u001b[39m RandomSampler(dataset, generator\u001b[39m=\u001b[39;49mgenerator) \u001b[39m# type: ignore[arg-type]\u001b[39;00m\n\u001b[1;32m 352\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 353\u001b[0m sampler \u001b[39m=\u001b[39m SequentialSampler(dataset) \u001b[39m# type: ignore[arg-type]\u001b[39;00m\n",
"File \u001b[0;32m~/miniconda3/envs/simclr_pytorch_reefs/lib/python3.8/site-packages/torch/utils/data/sampler.py:107\u001b[0m, in \u001b[0;36mRandomSampler.__init__\u001b[0;34m(self, data_source, replacement, num_samples, generator)\u001b[0m\n\u001b[1;32m 103\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mTypeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mreplacement should be a boolean value, but got \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 104\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mreplacement=\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mformat(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mreplacement))\n\u001b[1;32m 106\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnum_samples, \u001b[39mint\u001b[39m) \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnum_samples \u001b[39m<\u001b[39m\u001b[39m=\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[0;32m--> 107\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mnum_samples should be a positive integer \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 108\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mvalue, but got num_samples=\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mformat(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnum_samples))\n",
"\u001b[0;31mValueError\u001b[0m: num_samples should be a positive integer value, but got num_samples=0"
]
}
],
"source": [
"dataset = CTDataset(cfg, split='test_data', transform=True) # ditch split at some point\n",
"sample_loader = DataLoader(dataset, batch_size=1, shuffle=True) "
Expand Down Expand Up @@ -216,7 +243,7 @@
" return extract_multiple_embeddings(model, batch, device) # Assuming this function is able to handle batches\n",
"\n",
"# Get 10 embeddings\n",
"for _ in range(2000): # Change the range to get more batches\n",
"for _ in range(10): # Change the range to get more batches\n",
" batch_embeddings = get_next_batch_embeddings(sample_iterator, model_instance, device)\n",
" all_embeddings.append(batch_embeddings)\n",
"\n",
@@ -359,6 +386,104 @@
"\n",
"print(f\"All embeddings shape: {all_embeddings.shape}\")\n"
]
},
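If the stacked embeddings are to be reused by the PCA and random-forest steps, persisting them avoids re-running extraction. A one-line sketch; the filename is an editor's placeholder:

    torch.save(all_embeddings, 'all_embeddings.pt')  # hypothetical path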
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# PCA stuff"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from my_custom_dataset_eval import CTDataset_train, CTDataset_test\n",
"\n",
"split = 'test_data'\n",
"transform = False\n",
"train_percent = 0.8\n",
"\n",
"dataset = CTDataset_train(cfg, split=split, transform=transform, train_percent=train_percent)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def get_dataloader(cfg, split, transform, train_percent, batch_size, shuffle, num_workers):\n",
" \n",
" dataset = CTDataset_train(cfg, split, transform, train_percent)\n",
"\n",
" dataloader = DataLoader(\n",
" dataset, \n",
" batch_size=batch_size,\n",
" num_workers=num_workers,\n",
" shuffle=shuffle\n",
" )\n",
"\n",
" return dataloader\n",
"\n",
"sample_loader = get_dataloader(cfg, split, transform=False, train_percent = train_percent, batch_size=32, shuffle=False, num_workers=4) \n"
]
},
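Given the earlier DataLoader attempt failed on an empty dataset, it is worth confirming this loader actually yields batches before the forward pass. A small sketch using the names above:

    print(len(sample_loader.dataset), 'clips in', len(sample_loader), 'batches')
    xb, yb = next(iter(sample_loader))   # raises StopIteration if the dataset is empty
    print(xb.shape, yb.shape)            # expect [32, 3, W, H] and [32]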
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"ename": "NotImplementedError",
"evalue": "Module [SimClrPytorchResNet50] is missing the required \"forward\" function",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNotImplementedError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[13], line 21\u001b[0m\n\u001b[1;32m 18\u001b[0m X \u001b[39m=\u001b[39m X\u001b[39m.\u001b[39mto(device) \n\u001b[1;32m 20\u001b[0m \u001b[39m# Forward pass \u001b[39;00m\n\u001b[0;32m---> 21\u001b[0m output \u001b[39m=\u001b[39m model_instance(X)\n\u001b[1;32m 23\u001b[0m \u001b[39m# Extract embeddings\u001b[39;00m\n\u001b[1;32m 24\u001b[0m emb \u001b[39m=\u001b[39m output[\u001b[39m'\u001b[39m\u001b[39membeddings\u001b[39m\u001b[39m'\u001b[39m]\n",
"File \u001b[0;32m~/miniconda3/envs/simclr_pytorch_reefs/lib/python3.8/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_pre_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[39mor\u001b[39;00m _global_backward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1502\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
"File \u001b[0;32m~/miniconda3/envs/simclr_pytorch_reefs/lib/python3.8/site-packages/torch/nn/modules/module.py:363\u001b[0m, in \u001b[0;36m_forward_unimplemented\u001b[0;34m(self, *input)\u001b[0m\n\u001b[1;32m 352\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_forward_unimplemented\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39m*\u001b[39m\u001b[39minput\u001b[39m: Any) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 353\u001b[0m \u001b[39m \u001b[39m\u001b[39mr\u001b[39m\u001b[39m\"\"\"Defines the computation performed at every call.\u001b[39;00m\n\u001b[1;32m 354\u001b[0m \n\u001b[1;32m 355\u001b[0m \u001b[39m Should be overridden by all subclasses.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[39m registered hooks while the latter silently ignores them.\u001b[39;00m\n\u001b[1;32m 362\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 363\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mNotImplementedError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mModule [\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mtype\u001b[39m(\u001b[39mself\u001b[39m)\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m] is missing the required \u001b[39m\u001b[39m\\\"\u001b[39;00m\u001b[39mforward\u001b[39m\u001b[39m\\\"\u001b[39;00m\u001b[39m function\u001b[39m\u001b[39m\"\u001b[39m)\n",
"\u001b[0;31mNotImplementedError\u001b[0m: Module [SimClrPytorchResNet50] is missing the required \"forward\" function"
]
}
],
"source": [
"# Extract embeddings on GPU if available\n",
"if torch.cuda.is_available():\n",
" device = 'cuda'\n",
"else:\n",
" device = 'cpu'\n",
"\n",
"# Put model on device \n",
"model_instance.to(device)\n",
"model_instance.eval()\n",
"\n",
"embeddings = []\n",
"labels = []\n",
"\n",
"\n",
"with torch.no_grad():\n",
" for X, y in sample_loader:\n",
" # Move batch to device\n",
" X = X.to(device) \n",
" \n",
" # Forward pass \n",
" output = model_instance(X)\n",
" \n",
" # Extract embeddings\n",
" emb = output['embeddings']\n",
" \n",
" # Append\n",
" embeddings.append(emb.detach().cpu()) \n",
" labels.append(y.detach().cpu())\n",
" \n",
"# Concatenate \n",
"embeddings = torch.cat(embeddings)\n",
"labels = torch.cat(labels)"
]
}
],
"metadata": {
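The commit message says the PCA plots were made and saved, but those cells are not in the visible hunks. A minimal sketch of that step, assuming the embeddings and labels arrays from the rewrite above; the figure path is an editor's placeholder:

    import matplotlib.pyplot as plt
    from sklearn.decomposition import PCA

    pca = PCA(n_components=2)
    pts = pca.fit_transform(embeddings)            # [N, 2]
    fig, ax = plt.subplots(figsize=(6, 5))
    sc = ax.scatter(pts[:, 0], pts[:, 1], c=labels, s=8, cmap='coolwarm')
    ax.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.1%})')
    ax.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.1%})')
    ax.legend(*sc.legend_elements(), title='class')
    fig.savefig('pca_embeddings.png', dpi=200)     # hypothetical path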
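"Ran random forest" is likewise not shown in the diff. A plausible sketch under the same assumptions, with an 80/20 split; the hyperparameters here are the editor's, not the notebook's:

    from sklearn.ensemble import RandomForestClassifier
    from sklearn.metrics import accuracy_score
    from sklearn.model_selection import train_test_split

    X_tr, X_te, y_tr, y_te = train_test_split(
        embeddings, labels, test_size=0.2, stratify=labels, random_state=0)
    rf = RandomForestClassifier(n_estimators=300, random_state=0)
    rf.fit(X_tr, y_tr)
    print('accuracy:', accuracy_score(y_te, rf.predict(X_te)))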