Skip to content

Commit b93c661

Browse files
committed
Modify lm similarity example notebook to work without cuda
1 parent 7a34fcf commit b93c661

File tree

1 file changed

+20
-5
lines changed

1 file changed

+20
-5
lines changed

notebooks/lm_similarity_example.ipynb

+20-5
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,22 @@
77
"source": [
88
"# Language Model Similarity Example\n",
99
"\n",
10-
"This notebook shows how to provide a language model to a similarity anchor, allowing the utilization of knowledge inside embedding spaces as part of the ICAT model."
10+
"This notebook shows how to provide a language model to a similarity anchor, allowing the utilization of knowledge inside embedding spaces as part of the ICAT model.\n",
11+
"\n",
12+
"You will need to install the huggingface transformers and pytorch libraries for this notebook to run, please use\n",
13+
"```\n",
14+
"pip install transformers torch\n",
15+
"```"
16+
]
17+
},
18+
{
19+
"cell_type": "code",
20+
"execution_count": null,
21+
"id": "bb6a33c6-e0f2-414f-9356-97f6fb47e2b9",
22+
"metadata": {},
23+
"outputs": [],
24+
"source": [
25+
"import torch"
1126
]
1227
},
1328
{
@@ -21,7 +36,7 @@
2136
"source": [
2237
"# change these constants as needed based on your hardware constraints\n",
2338
"BATCH_SIZE = 16\n",
24-
"DEVICE = \"cuda\"\n",
39+
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
2540
"MODEL_NAME = \"bert-base-uncased\""
2641
]
2742
},
@@ -175,7 +190,7 @@
175190
"\n",
176191
"dataset = fetch_20newsgroups(subset=\"train\")\n",
177192
"df = pd.DataFrame({\"text\": dataset[\"data\"], \"category\": [dataset[\"target_names\"][i] for i in dataset[\"target\"]]})\n",
178-
"#df = df.iloc[0:1999]\n",
193+
"df = df.iloc[0:1999] # NOTE: if running on CPU or weaker GPU, recommend uncommenting this to avoid long processing times on first BERT anchor creation.\n",
179194
"df.head()"
180195
]
181196
},
@@ -196,7 +211,7 @@
196211
},
197212
"outputs": [],
198213
"source": [
199-
"icat.initialize(offline=True)"
214+
"icat.initialize(offline=False)"
200215
]
201216
},
202217
{
@@ -279,7 +294,7 @@
279294
"name": "python",
280295
"nbconvert_exporter": "python",
281296
"pygments_lexer": "ipython3",
282-
"version": "3.10.12"
297+
"version": "3.10.15"
283298
}
284299
},
285300
"nbformat": 4,

0 commit comments

Comments
 (0)