
Commit 95a72fd

update model training
Muyiyuan committed May 27, 2023
1 parent 200bb75 commit 95a72fd
Showing 2 changed files with 61 additions and 43 deletions.
49 changes: 27 additions & 22 deletions ABEdeepoff.ipynb
@@ -47,10 +47,19 @@
 },
 {
 "cell_type": "code",
-"execution_count": 36,
+"execution_count": 4,
 "id": "3fcca78f-dc04-4896-8430-c8d5b91e5f95",
 "metadata": {},
-"outputs": [],
+"outputs": [
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"/home/zcd/miniconda3/envs/pytorch/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.23.5\n",
+" warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n"
+]
+}
+],
 "source": [
 "import pkbar\n",
 "import torch\n",
@@ -262,12 +271,12 @@
 },
 {
 "cell_type": "code",
-"execution_count": 40,
+"execution_count": 9,
 "id": "addd6c5c-bd95-4c04-a258-9e4784c268cb",
 "metadata": {},
 "outputs": [],
 "source": [
-"def train_model(train_iter, valid_iter, test_iter, patience, k, j, version=None):\n",
+"def train_model(train_iter, valid_iter, test_iter, patience, k, version=None):\n",
 " global my_optim\n",
 " global lr_dict\n",
 " global lst_testing\n",
@@ -353,7 +362,7 @@
 " valid_loss = epoch_loss_ / len(valid_iter)\n",
 " if valid_loss < best_valid_loss:\n",
 " best_valid_loss = valid_loss\n",
-" p = f'{path}/ABE_RNN_{k}_{j}_{round(valid_loss, 6)}.pt'\n",
+" p = f'{path}/ABE_RNN_{k}_{round(valid_loss, 6)}.pt'\n",
 " torch.save(model.state_dict(), p)\n",
 " logger.info(f'----Testing------,{p}')\n",
 " r = get_test(test_iter, model)\n",
@@ -431,7 +440,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 16,
+"execution_count": 10,
 "id": "d3616bac-e1ae-470d-8fea-198ef0d74f6c",
 "metadata": {},
 "outputs": [],
@@ -451,7 +460,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 17,
+"execution_count": 11,
 "id": "c1800fb8-fb11-4c98-9622-d680dcb3d2e6",
 "metadata": {},
 "outputs": [],
@@ -465,7 +474,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 21,
+"execution_count": 12,
 "id": "04dca3f4-8e2c-4a7d-bbc2-cc08459c9b37",
 "metadata": {},
 "outputs": [],
@@ -503,7 +512,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 22,
+"execution_count": 13,
 "id": "d795b8f1-9278-4ba8-8b79-be115b1b44f1",
 "metadata": {},
 "outputs": [],
@@ -536,22 +545,18 @@
 " test_dataset = gRNADataset(data_test)\n",
 " test_iter = DataLoader(test_dataset, batch_size=128, shuffle=False, collate_fn=generate_batch)\n",
 " \n",
-" kf1 = GroupKFold(n_splits=10)\n",
 " data_train.reset_index(drop=True, inplace=True)\n",
-" splits1 = kf1.split(data_train, groups=data_train['group'])\n",
+" data_train, data_valid = train_test_split(data_train, train_size=0.9, random_state=SEED)\n",
 " \n",
-" for j, (train_idx, valid_idx) in enumerate(splits1):\n",
-" data_train1, data_valid = data_train.loc[train_idx], data_train.loc[valid_idx]\n",
-" \n",
-" train_dataset = gRNADataset(data_train1)\n",
-" train_iter = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=generate_batch)\n",
+" train_dataset = gRNADataset(data_train)\n",
+" train_iter = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=generate_batch)\n",
 "\n",
-" valid_dataset = gRNADataset(data_valid)\n",
-" valid_iter = DataLoader(valid_dataset, batch_size=32, shuffle=True, collate_fn=generate_batch)\n",
-" \n",
-" my_optim = optim.AdamW(model.parameters(), lr=lr_dict[0])\n",
-" model.apply(init_weights)\n",
-" train_model(train_iter, valid_iter, test_iter, 5, i, j, version='ABEdeepoff_0504')"
+" valid_dataset = gRNADataset(data_valid)\n",
+" valid_iter = DataLoader(valid_dataset, batch_size=32, shuffle=True, collate_fn=generate_batch)\n",
+"\n",
+" my_optim = optim.AdamW(model.parameters(), lr=lr_dict[0])\n",
+" model.apply(init_weights)\n",
+" train_model(train_iter, valid_iter, test_iter, 5, i, version='ABEdeepoff_0525')"
 ]
 },
 {
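Note: the net effect of the final hunk above is to replace the inner 10-fold GroupKFold validation loop with a single seeded 90/10 train/validation split per outer fold. A self-contained toy sketch of that split (the DataFrame contents and SEED value here are illustrative assumptions; the notebook uses its own data and its own SEED):

    # Toy illustration of the new 90/10 split; SEED = 42 is an assumption here.
    import pandas as pd
    from sklearn.model_selection import train_test_split

    SEED = 42
    df = pd.DataFrame({"target_seq": [f"seq{i}" for i in range(100)],
                       "efficiency": [i / 100 for i in range(100)]})
    train_df, valid_df = train_test_split(df, train_size=0.9, random_state=SEED)
    print(len(train_df), len(valid_df))  # 90 10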
55 changes: 34 additions & 21 deletions CBEdeepoff.ipynb
@@ -47,10 +47,19 @@
 },
 {
 "cell_type": "code",
-"execution_count": 17,
+"execution_count": 4,
 "id": "3fcca78f-dc04-4896-8430-c8d5b91e5f95",
 "metadata": {},
-"outputs": [],
+"outputs": [
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"/home/zcd/miniconda3/envs/pytorch/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.23.5\n",
+" warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n"
+]
+}
+],
 "source": [
 "import pkbar\n",
 "import torch\n",
@@ -262,12 +271,12 @@
 },
 {
 "cell_type": "code",
-"execution_count": 18,
+"execution_count": 9,
 "id": "addd6c5c-bd95-4c04-a258-9e4784c268cb",
 "metadata": {},
 "outputs": [],
 "source": [
-"def train_model(train_iter, valid_iter, test_iter, patience, k, j, version=None):\n",
+"def train_model(train_iter, valid_iter, test_iter, patience, k, version=None):\n",
 " global my_optim\n",
 " global lr_dict\n",
 " global lst_testing\n",
@@ -353,7 +362,7 @@
 " valid_loss = epoch_loss_ / len(valid_iter)\n",
 " if valid_loss < best_valid_loss:\n",
 " best_valid_loss = valid_loss\n",
-" p = f'{path}/CBE_RNN_{k}_{j}_{round(valid_loss, 6)}.pt'\n",
+" p = f'{path}/CBE_RNN_{k}_{round(valid_loss, 6)}.pt'\n",
 " torch.save(model.state_dict(), p)\n",
 " logger.info(f'----Testing------,{p}')\n",
 " r = get_test(test_iter, model)\n",
@@ -431,7 +440,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 11,
+"execution_count": 10,
 "id": "d3616bac-e1ae-470d-8fea-198ef0d74f6c",
 "metadata": {},
 "outputs": [],
@@ -451,7 +460,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 12,
+"execution_count": 11,
 "id": "c1800fb8-fb11-4c98-9622-d680dcb3d2e6",
 "metadata": {},
 "outputs": [],
@@ -465,7 +474,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 13,
+"execution_count": 12,
 "id": "04dca3f4-8e2c-4a7d-bbc2-cc08459c9b37",
 "metadata": {},
 "outputs": [],
@@ -503,7 +512,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 14,
+"execution_count": 13,
 "id": "d795b8f1-9278-4ba8-8b79-be115b1b44f1",
 "metadata": {},
 "outputs": [],
@@ -536,23 +545,27 @@
 " test_dataset = gRNADataset(data_test)\n",
 " test_iter = DataLoader(test_dataset, batch_size=128, shuffle=False, collate_fn=generate_batch)\n",
 " \n",
-" kf1 = GroupKFold(n_splits=10)\n",
 " data_train.reset_index(drop=True, inplace=True)\n",
-" splits1 = kf1.split(data_train, groups=data_train['group'])\n",
+" data_train, data_valid = train_test_split(data_train, train_size=0.9, random_state=SEED)\n",
 " \n",
-" for j, (train_idx, valid_idx) in enumerate(splits1):\n",
-" data_train1, data_valid = data_train.loc[train_idx], data_train.loc[valid_idx]\n",
-" \n",
-" train_dataset = gRNADataset(data_train1)\n",
-" train_iter = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=generate_batch)\n",
+" train_dataset = gRNADataset(data_train)\n",
+" train_iter = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=generate_batch)\n",
 "\n",
-" valid_dataset = gRNADataset(data_valid)\n",
-" valid_iter = DataLoader(valid_dataset, batch_size=32, shuffle=True, collate_fn=generate_batch)\n",
+" valid_dataset = gRNADataset(data_valid)\n",
+" valid_iter = DataLoader(valid_dataset, batch_size=32, shuffle=True, collate_fn=generate_batch)\n",
 "\n",
-" my_optim = optim.AdamW(model.parameters(), lr=lr_dict[0])\n",
-" model.apply(init_weights)\n",
-" train_model(train_iter, valid_iter, test_iter, 5, i, j, version='CBEdeepoff_0504')"
+" my_optim = optim.AdamW(model.parameters(), lr=lr_dict[0])\n",
+" model.apply(init_weights)\n",
+" train_model(train_iter, valid_iter, test_iter, 5, i, version='CBEdeepoff_0525')"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "9fba6b4c-b52d-4afc-871f-cf82e3b4b074",
+"metadata": {},
+"outputs": [],
+"source": []
+}
 ],
 "metadata": {
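Note: both notebooks now also capture a SciPy stderr warning in their first cell's outputs: the environment's NumPy 1.23.5 falls outside the range this SciPy build was tested with (>=1.16.5, <1.23.0). A small sketch reproducing that bounds check (the version literals are taken from the captured warning; assumes the packaging module is installed, which it usually is alongside pip):

    # Compare installed NumPy against the range SciPy's import check expects.
    from packaging.version import Version
    import numpy as np

    np_minversion, np_maxversion = Version("1.16.5"), Version("1.23.0")
    detected = Version(np.__version__)  # 1.23.5 in the committed environment
    if not (np_minversion <= detected < np_maxversion):
        print(f"NumPy {detected} is outside SciPy's tested range "
              f">={np_minversion},<{np_maxversion}; imports still work, but "
              f"pinning numpy<1.23 would silence the warning.")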
