|
7 | 7 | "source": [
|
8 | 8 | "# Language Model Similarity Example\n",
|
9 | 9 | "\n",
|
10 |
| - "This notebook shows how to provide a language model to a similarity anchor, allowing the utilization of knowledge inside embedding spaces as part of the ICAT model." |
| 10 | + "This notebook shows how to provide a language model to a similarity anchor, allowing the utilization of knowledge inside embedding spaces as part of the ICAT model.\n", |
| 11 | + "\n", |
| 12 | + "You will need to install the huggingface transformers and pytorch libraries for this notebook to run, please use\n", |
| 13 | + "```\n", |
| 14 | + "pip install transformers torch\n", |
| 15 | + "```" |
| 16 | + ] |
| 17 | + }, |
| 18 | + { |
| 19 | + "cell_type": "code", |
| 20 | + "execution_count": null, |
| 21 | + "id": "bb6a33c6-e0f2-414f-9356-97f6fb47e2b9", |
| 22 | + "metadata": {}, |
| 23 | + "outputs": [], |
| 24 | + "source": [ |
| 25 | + "import torch" |
11 | 26 | ]
|
12 | 27 | },
|
13 | 28 | {
|
|
21 | 36 | "source": [
|
22 | 37 | "# change these constants as needed based on your hardware constraints\n",
|
23 | 38 | "BATCH_SIZE = 16\n",
|
24 |
| - "DEVICE = \"cuda\"\n", |
| 39 | + "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", |
25 | 40 | "MODEL_NAME = \"bert-base-uncased\""
|
26 | 41 | ]
|
27 | 42 | },
|
|
175 | 190 | "\n",
|
176 | 191 | "dataset = fetch_20newsgroups(subset=\"train\")\n",
|
177 | 192 | "df = pd.DataFrame({\"text\": dataset[\"data\"], \"category\": [dataset[\"target_names\"][i] for i in dataset[\"target\"]]})\n",
|
178 |
| - "#df = df.iloc[0:1999]\n", |
| 193 | + "df = df.iloc[0:1999] # NOTE: if running on CPU or weaker GPU, recommend uncommenting this to avoid long processing times on first BERT anchor creation.\n", |
179 | 194 | "df.head()"
|
180 | 195 | ]
|
181 | 196 | },
|
|
196 | 211 | },
|
197 | 212 | "outputs": [],
|
198 | 213 | "source": [
|
199 |
| - "icat.initialize(offline=True)" |
| 214 | + "icat.initialize(offline=False)" |
200 | 215 | ]
|
201 | 216 | },
|
202 | 217 | {
|
|
279 | 294 | "name": "python",
|
280 | 295 | "nbconvert_exporter": "python",
|
281 | 296 | "pygments_lexer": "ipython3",
|
282 |
| - "version": "3.10.12" |
| 297 | + "version": "3.10.15" |
283 | 298 | }
|
284 | 299 | },
|
285 | 300 | "nbformat": 4,
|
|
0 commit comments