diff --git a/run_localGPT.py b/run_localGPT.py
index 5f01a4e7..24a1978a 100644
--- a/run_localGPT.py
+++ b/run_localGPT.py
@@ -130,7 +130,7 @@ def get_embeddings():
     if "instructor" in EMBEDDING_MODEL_NAME:
         return HuggingFaceInstructEmbeddings(
             model_name=EMBEDDING_MODEL_NAME,
-            model_kwargs={"device": compute_device},
+            model_kwargs={"device": device_type},
             embed_instruction='Represent the document for retrieval:',
             query_instruction='Represent the question for retrieving supporting documents:'
         )
@@ -138,14 +138,14 @@ def get_embeddings():
     elif "bge" in EMBEDDING_MODEL_NAME:
         return HuggingFaceBgeEmbeddings(
             model_name=EMBEDDING_MODEL_NAME,
-            model_kwargs={"device": compute_device},
+            model_kwargs={"device": device_type},
             query_instruction='Represent this sentence for searching relevant passages:'
         )
     else:
         return HuggingFaceEmbeddings(
             model_name=EMBEDDING_MODEL_NAME,
-            model_kwargs={"device": compute_device},
+            model_kwargs={"device": device_type},
         )

 embeddings = get_embeddings()
 logging.info(f"Loaded embeddings from {EMBEDDING_MODEL_NAME}")
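
The change replaces the undefined name compute_device with device_type in all three embedding branches, so the embedding model is loaded on the device selected at runtime. A minimal, self-contained sketch of the corrected selection logic follows; it assumes device_type is passed in as a parameter ("cuda", "cpu", or "mps") and uses an illustrative EMBEDDING_MODEL_NAME, so treat it as a sketch of the pattern rather than the exact code in run_localGPT.py.

# Sketch of the corrected embedding selection, assuming device_type is a
# parameter ("cuda", "cpu", or "mps"). EMBEDDING_MODEL_NAME is illustrative.
import logging
from langchain.embeddings import (
    HuggingFaceBgeEmbeddings,
    HuggingFaceEmbeddings,
    HuggingFaceInstructEmbeddings,
)

EMBEDDING_MODEL_NAME = "hkunlp/instructor-large"  # assumed value for illustration


def get_embeddings(device_type="cuda"):
    # Instructor models take separate document- and query-side instructions.
    if "instructor" in EMBEDDING_MODEL_NAME:
        return HuggingFaceInstructEmbeddings(
            model_name=EMBEDDING_MODEL_NAME,
            model_kwargs={"device": device_type},
            embed_instruction="Represent the document for retrieval:",
            query_instruction="Represent the question for retrieving supporting documents:",
        )
    # BGE models only take a query-side instruction.
    if "bge" in EMBEDDING_MODEL_NAME:
        return HuggingFaceBgeEmbeddings(
            model_name=EMBEDDING_MODEL_NAME,
            model_kwargs={"device": device_type},
            query_instruction="Represent this sentence for searching relevant passages:",
        )
    # Plain sentence-transformers fallback.
    return HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL_NAME,
        model_kwargs={"device": device_type},
    )


embeddings = get_embeddings("cpu")
logging.info(f"Loaded embeddings from {EMBEDDING_MODEL_NAME}")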