Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: define schemas as classes #99

Merged
merged 9 commits into from
Mar 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,4 @@ jobs:
run: poetry install

- name: Run tests
env:
HUGGINGFACE_KEY: ${{ secrets.HF_API_TOKEN }}
run: poetry run pytest
22 changes: 11 additions & 11 deletions datastew/process/jsonl_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@

from datastew.embedding import EmbeddingModel
from datastew.repository import WeaviateRepository
from datastew.repository.weaviate_schema import (concept_schema,
mapping_schema_user_vectors,
terminology_schema)
from datastew.repository.weaviate_schema import (
concept_schema,
mapping_schema_user_vectors,
terminology_schema,
)


class WeaviateJsonlConverter(object):
Expand All @@ -20,15 +22,15 @@ class WeaviateJsonlConverter(object):
def __init__(
self,
dest_dir: str,
terminology_schema: dict = terminology_schema,
concept_schema: dict = concept_schema,
mapping_schema: dict = mapping_schema_user_vectors,
terminology_schema: dict = terminology_schema.schema,
concept_schema: dict = concept_schema.schema,
mapping_schema: dict = mapping_schema_user_vectors.schema,
buffer_size: int = 1000,
):
self.dest_dir = dest_dir
self.terminology_schema = terminology_schema
self.concept_schema = concept_schema
self.mapping_schema_user_vectors = mapping_schema
self.mapping_schema = mapping_schema
self._buffer = []
self._buffer_size = buffer_size
self._ensure_directories_exist()
Expand Down Expand Up @@ -107,9 +109,7 @@ def from_repository(self, repository: WeaviateRepository) -> None:

# Process mapping last
mapping_file_path = self._get_file_path("mapping")
for mapping in repository.get_iterator(
self.mapping_schema_user_vectors["class"]
):
for mapping in repository.get_iterator(self.mapping_schema["class"]):
self._write_to_jsonl(
mapping_file_path, self._weaviate_object_to_dict(mapping)
)
Expand Down Expand Up @@ -184,7 +184,7 @@ def from_ohdsi(self, src: str, embedding_model: EmbeddingModel):
mapping_uuid = generate_uuid5({"text": concept_names[i]})
mappings.append(
{
"class": self.mapping_schema_user_vectors["class"],
"class": self.mapping_schema["class"],
"id": mapping_uuid,
"properties": {
"text": concept_names[i],
Expand Down
95 changes: 58 additions & 37 deletions datastew/repository/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,14 @@
from datastew.repository.model import MappingResult
from datastew.repository.pagination import Page
from datastew.repository.weaviate_schema import (
concept_schema, mapping_schema_preconfigured_embeddings,
mapping_schema_user_vectors, terminology_schema)
ConceptSchema,
MappingSchema,
TerminologySchema,
concept_schema,
mapping_schema_preconfigured_embeddings,
mapping_schema_user_vectors,
terminology_schema,
)


class WeaviateRepository(BaseRepository):
Expand All @@ -34,36 +40,44 @@ def __init__(
port: int = 80,
http_port: int = 8079,
grpc_port: int = 50050,
terminology_schema: TerminologySchema = terminology_schema,
concept_schema: ConceptSchema = concept_schema,
mapping_schema: MappingSchema = mapping_schema_user_vectors,
):
"""Initialize the WeaviateRepository instance, connecting to either a local or remote Weaviate instance and
setting up the appropriate schemas based on the specific options.

:param use_weaviate_vectorizer: Specifies whether to use pre-configured embeddings (True) or custom vectors
provided by the user (False). Defaults to False.
:param huggingface_key: API key for Hugging Face if using pre-configured embeddings. Required if `use_weaviate_vectorizer`
is True. Defaults to None.
:param huggingface_key: API key for Hugging Face if using pre-configured embeddings. Required if
`use_weaviate_vectorizer` is True. Defaults to None.
:param mode: Defines the connection mode for the repository. Can be either "memory" (in-memory), or "remote"
(remote Weaviate instance). Defaults to "memory".
:param path: The path for the local disk connection, used only in "memory" mode. Defaults to "db".
:param port: The port number for remote Weaviate connection, used only in "remote" mode. Defaults to 80.
:param http_port: The HTTP port for the local connection in "memory" mode. Defaults to 8079.
:param grpc_port: The gRPC port for the local connection in "memory" mode. Defaults to 50050.

:param terminology_schema: Terminology schema to use for the repository. Defaults to pre-configured
`terminology_schema`.
:param concept_schema: Concept schema to use for the repository. Defaults to pre-configured `concept_schema`.
:param mapping_schema: Mapping schema to use for the repository. Defaults to pre-configured
`mapping_schema_user_vectors`. If `use_weaviate_vectorizer` is set to True and `mapping_schema` does not
include a `vectorizer_config` key, the schema will default to pre-configured
`mapping_schema_preconfigured_embeddings`.
:raises ValueError: If the `huggingface_key` is not provided when `use_weaviate_vectorizer` is True or if an
invalid `mode` is specified.
:raises RuntimeError: If there is a failure in creating the schema or connecting to Weaviate.
"""
self.use_weaviate_vectorizer = use_weaviate_vectorizer
self.mode = mode
self.terminology_schema = terminology_schema
self.concept_schema = concept_schema
self.mapping_schema = mapping_schema
self.client: Optional[WeaviateClient] = None
self.headers = None
if self.use_weaviate_vectorizer:
if huggingface_key:
self.headers = {"X-HuggingFace-Api-Key": huggingface_key}
else:
raise ValueError(
"A HuggingFace API key is required for generating vectors."
)
if self.mode == "memory":
self._connect_to_memory(path, http_port, grpc_port)
elif self.mode == "remote":
Expand All @@ -74,14 +88,20 @@ def __init__(
)

try:
self._create_schema_if_not_exists(terminology_schema)
self._create_schema_if_not_exists(concept_schema)
if not self.use_weaviate_vectorizer:
self._create_schema_if_not_exists(mapping_schema_user_vectors)
else:
self._create_schema_if_not_exists(
mapping_schema_preconfigured_embeddings
self._create_schema_if_not_exists(self.terminology_schema.schema)
self._create_schema_if_not_exists(self.concept_schema.schema)
if (
self.use_weaviate_vectorizer
and not self.mapping_schema.schema["vectorizer_config"]
):
self.logger.warning(
"Provided mapping schema lacks `vectorizer_config` even though"
"`use_weaviate_vectorizer` is set to True. Defaulting to"
f"{mapping_schema_preconfigured_embeddings.schema}"
)
self.mapping_schema = mapping_schema_preconfigured_embeddings
self._create_schema_if_not_exists(self.mapping_schema.schema)

except Exception as e:
raise RuntimeError(f"Failed to create schema: {e}")

Expand Down Expand Up @@ -142,7 +162,7 @@ def import_data_dictionary(
mapping = Mapping(
concept=concept,
text=description,
embedding=variable_to_embedding[variable],
embedding=list(variable_to_embedding[variable]),
sentence_embedder=embedding_model_name,
)
else:
Expand Down Expand Up @@ -171,17 +191,18 @@ def store_all(
self.store(instance)

def get_iterator(self, collection: Literal["Concept", "Mapping", "Terminology"]):
if collection == "Concept":
return_references = QueryReference(link_on="hasTerminology")
elif collection == "Mapping":
return_references = QueryReference(link_on="hasConcept")
elif collection == "Terminology":
return_references = None
else:
raise ValueError(f"Collection {collection} is not supported.")
return self.client.collections.get(collection).iterator(
include_vector=True, return_references=return_references
)
if self.client:
if collection == "Concept":
return_references = QueryReference(link_on="hasTerminology")
elif collection == "Mapping":
return_references = QueryReference(link_on="hasConcept")
elif collection == "Terminology":
return_references = None
else:
raise ValueError(f"Collection {collection} is not supported.")
return self.client.collections.get(collection).iterator(
include_vector=True, return_references=return_references
)

def get_all_sentence_embedders(self) -> List[str]:
"""Retrieves the names of all sentence embedders used in the "Mapping" collection. If
Expand Down Expand Up @@ -251,7 +272,7 @@ def get_concepts(
concept_collection = self.client.collections.get("Concept")

total_count = (
self.client.collections.get(concept_schema["class"])
self.client.collections.get(self.concept_schema.schema["class"])
.aggregate.over_all(total_count=True)
.total_count
)
Expand Down Expand Up @@ -500,7 +521,7 @@ def get_mappings(

# Fetch the total count of mappings for pagination
total_count = (
self.client.collections.get(mapping_schema_user_vectors["class"])
self.client.collections.get(self.mapping_schema.schema["class"])
.aggregate.over_all(total_count=True)
.total_count
)
Expand Down Expand Up @@ -658,7 +679,7 @@ def get_closest_mappings_with_similarities(
embedding: Sequence[float],
sentence_embedder: Optional[str] = None,
limit=5,
) -> List[MappingResult]:
) -> Sequence[MappingResult]:
"""Fetches the closest mappings based on an embedding vector and includes similarity scores for each mapping.

:param embedding: The embedding vector to find the closest mappings.
Expand Down Expand Up @@ -686,7 +707,7 @@ def get_terminology_and_model_specific_closest_mappings(
terminology_name: str,
sentence_embedder_name: str,
limit: int = 5,
) -> List[Mapping]:
) -> Sequence[Mapping]:
"""Fetches the closest mappings for a given terminology and sentence embedder model.

This function is deprecated and will be removed in a future release. It is recommended to use
Expand Down Expand Up @@ -717,7 +738,7 @@ def get_terminology_and_model_specific_closest_mappings_with_similarities(
terminology_name: str,
sentence_embedder_name: str,
limit: int = 5,
) -> List[MappingResult]:
) -> Sequence[MappingResult]:
"""Fetches the closest mappings for a given terminology and sentence embedder model, includes similarity scores.

This function is deprecated and will be removed in a future release. It is recommended to use
Expand Down Expand Up @@ -859,9 +880,7 @@ def import_from_jsonl(

# Validate essential fields
if "id" not in item or "properties" not in item:
raise ValueError(
f"Missing 'id' or 'properties' on line {idx}"
)
raise ValueError(f"Missing 'id' or 'properties' on line {idx}")
chunk.append(item)
if len(chunk) >= chunk_size:
self._process_batch(chunk, collection)
Expand Down Expand Up @@ -1037,7 +1056,9 @@ def _connect_to_memory(self, path: str, http_port: int, grpc_port: int):
)
else:
self.client = weaviate.connect_to_embedded(
persistence_data_path=path, headers=self.headers
persistence_data_path=path,
headers=self.headers,
environment_variables={"ENABLE_MODULES": "text2vec-ollama"},
)
except Exception as e:
raise ConnectionError(f"Failed to initialize Weaviate client: {e}")
Expand Down
127 changes: 84 additions & 43 deletions datastew/repository/weaviate_schema.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,87 @@
from weaviate.classes.config import (Configure, DataType, Property,
ReferenceProperty)

terminology_schema = {
"class": "Terminology",
"description": "A terminology entry",
"properties": [Property(name="name", data_type=DataType.TEXT)],
}

concept_schema = {
"class": "Concept",
"description": "A concept entry",
"properties": [
Property(name="conceptID", data_type=DataType.TEXT),
Property(name="prefLabel", data_type=DataType.TEXT),
],
"references": [
ReferenceProperty(name="hasTerminology", target_collection="Terminology")
],
}

mapping_schema_user_vectors = {
"class": "Mapping",
"description": "A mapping entry",
"properties": [
Property(name="text", data_type=DataType.TEXT),
Property(name="hasSentenceEmbedder", data_type=DataType.TEXT),
],
"references": [ReferenceProperty(name="hasConcept", target_collection="Concept")],
}

mapping_schema_preconfigured_embeddings = {
"class": "Mapping",
"description": "A mapping entry",
"properties": [
Property(name="text", data_type=DataType.TEXT),
],
"references": [ReferenceProperty(name="hasConcept", target_collection="Concept")],
"vectorizer_config": [
Configure.NamedVectors.text2vec_huggingface(
name="sentence_transformers_all_mpnet_base_v2",
from typing import List, Optional

from weaviate.classes.config import Configure, DataType, Property, ReferenceProperty
from weaviate.collections.classes.config_named_vectors import _NamedVectorConfigCreate


class WeaviateSchema:
    """Description of a single Weaviate collection.

    The constructor arguments are kept as attributes and are also gathered into ``schema``, a dict with the keys
    ``"class"``, ``"description"``, ``"properties"``, ``"references"`` and ``"vectorizer_config"``.
    """

    def __init__(
        self,
        name: str,
        description: str,
        properties: Optional[List[Property]] = None,
        references: Optional[List[ReferenceProperty]] = None,
        vectorizer_config: Optional[List[_NamedVectorConfigCreate]] = None,
    ):
        """Build the schema description.

        List arguments are shallow-copied, so the instance never aliases a list owned by the caller or a list
        literal used as a default argument by a subclass. Any other value (including ``None``) is stored unchanged.

        :param name: Collection name; stored as ``class_name`` and under the ``"class"`` key of ``schema``.
        :param description: Human-readable description of the collection.
        :param properties: Property definitions of the collection. Defaults to None.
        :param references: Cross-reference definitions of the collection. Defaults to None.
        :param vectorizer_config: Named vector configurations of the collection. Defaults to None.
        """
        self.class_name = name
        self.description = description
        self.properties = self._detach(properties)
        self.references = self._detach(references)
        self.vectorizer_config = self._detach(vectorizer_config)
        # `schema` deliberately refers to the very same objects as the attributes above.
        self.schema = {
            "class": self.class_name,
            "description": self.description,
            "properties": self.properties,
            "references": self.references,
            "vectorizer_config": self.vectorizer_config,
        }

    @staticmethod
    def _detach(value):
        """Return a shallow copy of ``value`` if it is a list, otherwise ``value`` itself."""
        # Only real lists are copied: calling list() on an arbitrary object (e.g. a single config object that
        # happens to be iterable) would silently change its meaning.
        return list(value) if isinstance(value, list) else value


class TerminologySchema(WeaviateSchema):
    """Schema of the terminology collection.

    Without arguments it describes a collection named ``"Terminology"`` with a single text property ``name``.
    """

    def __init__(
        self,
        name: str = "Terminology",
        description: str = "A terminology entry",
        properties: List[Property] = [Property(name="name", data_type=DataType.TEXT)],
    ):
        """Build the terminology schema.

        :param name: Collection name. Defaults to "Terminology".
        :param description: Description of the collection. Defaults to "A terminology entry".
        :param properties: Property definitions. Defaults to a single text property ``name``.
        """
        # The default list is created once, when the function is defined, and would otherwise be shared by every
        # instance. Hand the base class a private shallow copy so that mutating one schema cannot leak into another.
        own_properties = list(properties) if isinstance(properties, list) else properties
        super().__init__(name=name, description=description, properties=own_properties)


class ConceptSchema(WeaviateSchema):
    """Schema of the concept collection.

    Without arguments it describes a collection named ``"Concept"`` with the text properties ``conceptID`` and
    ``prefLabel`` and a ``hasTerminology`` reference targeting the ``Terminology`` collection.
    """

    def __init__(
        self,
        name: str = "Concept",
        description: str = "A concept entry",
        properties: List[Property] = [
            Property(name="conceptID", data_type=DataType.TEXT),
            Property(name="prefLabel", data_type=DataType.TEXT),
        ],
        references: List[ReferenceProperty] = [
            ReferenceProperty(name="hasTerminology", target_collection="Terminology")
        ],
    ):
        """Build the concept schema.

        :param name: Collection name. Defaults to "Concept".
        :param description: Description of the collection. Defaults to "A concept entry".
        :param properties: Property definitions. Defaults to the text properties ``conceptID`` and ``prefLabel``.
        :param references: Reference definitions. Defaults to ``hasTerminology`` -> ``Terminology``.
        """
        # The default lists are created once, when the function is defined, and would otherwise be shared by every
        # instance. Hand the base class private shallow copies so that mutating one schema cannot leak into another.
        own_properties = list(properties) if isinstance(properties, list) else properties
        own_references = list(references) if isinstance(references, list) else references
        super().__init__(
            name=name,
            description=description,
            properties=own_properties,
            references=own_references,
        )


class MappingSchema(WeaviateSchema):
    """Schema of the mapping collection.

    Without arguments it describes a collection named ``"Mapping"`` with the text properties ``text`` and
    ``hasSentenceEmbedder``, a ``hasConcept`` reference targeting the ``Concept`` collection and no vectorizer
    configuration.
    """

    def __init__(
        self,
        name: str = "Mapping",
        description: str = "A mapping entry",
        properties: List[Property] = [
            Property(name="text", data_type=DataType.TEXT),
            Property(name="hasSentenceEmbedder", data_type=DataType.TEXT),
        ],
        references: List[ReferenceProperty] = [
            ReferenceProperty(name="hasConcept", target_collection="Concept")
        ],
        vectorizer_config: Optional[List[_NamedVectorConfigCreate]] = None,
    ):
        """Build the mapping schema.

        :param name: Collection name. Defaults to "Mapping".
        :param description: Description of the collection. Defaults to "A mapping entry".
        :param properties: Property definitions. Defaults to the text properties ``text`` and
            ``hasSentenceEmbedder``.
        :param references: Reference definitions. Defaults to ``hasConcept`` -> ``Concept``.
        :param vectorizer_config: Named vector configurations. Defaults to None.
        """
        # The default lists are created once, when the function is defined, and would otherwise be shared by every
        # instance. Hand the base class private shallow copies so that mutating one schema cannot leak into another.
        own_properties = list(properties) if isinstance(properties, list) else properties
        own_references = list(references) if isinstance(references, list) else references
        super().__init__(
            name=name,
            description=description,
            properties=own_properties,
            references=own_references,
            vectorizer_config=vectorizer_config,
        )


terminology_schema = TerminologySchema()
concept_schema = ConceptSchema()
mapping_schema_user_vectors = MappingSchema()
mapping_schema_preconfigured_embeddings = MappingSchema(
properties=[Property(name="text", data_type=DataType.TEXT)],
vectorizer_config=[
Configure.NamedVectors.text2vec_ollama(
name="nomic_embed_text",
source_properties=["text"],
model="sentence-transformers/all-mpnet-base-v2",
api_endpoint="http://localhost:11434",
model="nomic-embed-text",
vectorize_collection_name=False,
)
],
}
)
Loading