
Commit 1323233

Add examples to walkthrough notebook for how to predict on new data
1 parent f2532d9 commit 1323233

File tree

1 file changed: +119 -1 lines changed

notebooks/usage_walkthrough.ipynb

+119-1
@@ -710,7 +710,7 @@
 {
 "cell_type": "code",
 "execution_count": null,
-"id": "102ddd7d-9421-4043-8b6c-5c98abf0a5c6",
+"id": "96972e35-62fe-49bf-a825-357e3addbe54",
 "metadata": {},
 "outputs": [],
 "source": [
@@ -727,6 +727,124 @@
 "source": [
 "It's definitely not stellar, but this is a solid starting point given that we only gave it 12 labeled points and 5 features, involving a total of only 15 keywords."
 ]
+},
+{
+"cell_type": "markdown",
+"id": "be982bb2-a74d-4696-a5dc-9834017bcbe1",
+"metadata": {},
+"source": [
+"While one of the main use cases for this tool is interactively filtering down a dataset, it is training a model under the hood, so the resulting model can in principle be applied to other datasets. As an example, we demonstrate on the test subset of the 20 newsgroups data:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "d77f39d9-9434-4c3e-9aa3-0da6a2f69a28",
+"metadata": {},
+"outputs": [],
+"source": [
+"dataset_test = fetch_20newsgroups(subset=\"test\")\n",
+"df_test = pd.DataFrame({\"text\": dataset_test[\"data\"], \"category\": [dataset_test[\"target_names\"][i] for i in dataset_test[\"target\"]]})\n",
+"df_test.category.value_counts()"
+]
+},
+{
+"cell_type": "markdown",
+"id": "7afccd3e-bbe1-433d-bed5-67453114ae11",
+"metadata": {},
+"source": [
+"Running the model on new data can be done in a few ways:\n",
+"\n",
+"1. Call `model.predict()` and pass in the dataframe to run predictions on"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "6c3c62f0-7bb6-4952-bb63-236e20f7bd57",
+"metadata": {},
+"outputs": [],
+"source": [
+"preds = model.predict(df_test)\n",
+"np.where(preds >= .5)"
+]
+},
+{
+"cell_type": "markdown",
+"id": "3ea92fd0-4215-4026-8418-3719311ca22a",
+"metadata": {},
+"source": [
+"2. Set the active data in the model to the new dataframe with `model.data.set_data()`, then explore/label/analyze in the interface as normal"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "3d3bddc5-d9ac-43ee-9517-95ec66ae48af",
+"metadata": {},
+"outputs": [],
+"source": [
+"model.data.set_data(df_test)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "8ae5e890-7676-4d04-bcc8-66e500c73e75",
+"metadata": {},
+"outputs": [],
+"source": [
+"model.data.active_data[model.data.active_data._pred >= .5]"
+]
+},
+{
+"cell_type": "markdown",
+"id": "97f36c4e-cd5e-4c3c-af54-d3fb2d986241",
+"metadata": {},
+"source": [
+"<div style=\"margin-left: 20px; background-color: #00796B; color: white; padding: 10px;\">\n",
+"Note: although the interface no longer has the original dataset, this is \"non-destructive\" to the model: all labeled data is copied separately into <code>model.training_data</code>, so additional points can be labeled, new features can be added, etc. without losing any of the \"original\" signal.\n",
+"</div>"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "d671eb78-4fd4-4006-a442-7ae449f5cadb",
+"metadata": {},
+"outputs": [],
+"source": [
+"model.training_data"
+]
+},
+{
+"cell_type": "markdown",
+"id": "84a2065c-9a34-4fe6-98a2-4e81c1268f46",
+"metadata": {},
+"source": [
+"3. The underlying scikit-learn model (a logistic regression) is accessible at `model.classifier`, so you could in principle use it directly (note that this assumes you are separately \"featurizing\" any new data yourself, e.g. with `model.featurize()`)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "d9fef6ab-ea0d-4978-82fe-2872aedbca5c",
+"metadata": {},
+"outputs": [],
+"source": [
+"model.classifier, model.classifier.coef_"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "48ec3627-1c4b-4705-a652-44894d20387f",
+"metadata": {},
+"outputs": [],
+"source": [
+"featurized_df = model.featurize(df_test, normalize=False).drop(columns=[\"text\", \"category\"])\n",
+"np.where(model.classifier.predict_proba(featurized_df)[:,1] >= .5)"
+]
 }
 ],
 "metadata": {
