Skip to content

Commit

Permalink
Replace Model with ModelRouter
Browse files Browse the repository at this point in the history
  • Loading branch information
jenniferjiangkells committed Nov 4, 2024
1 parent 77263ec commit dd3858e
Show file tree
Hide file tree
Showing 9 changed files with 116 additions and 47 deletions.
3 changes: 1 addition & 2 deletions healthchain/io/cdaconnector.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ def output(self, out_data: Document) -> CdaResponse:
`overwrite` attribute of the CdaConnector instance.
"""

out_data.generate_ccd(overwrite=self.overwrite)
updated_ccd_data = out_data.get_ccd_data()
updated_ccd_data = out_data.generate_ccd(overwrite=self.overwrite)

# Update the CDA document with the results

Expand Down
2 changes: 0 additions & 2 deletions healthchain/pipeline/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from healthchain.pipeline.base import BasePipeline, Pipeline
from healthchain.pipeline.components.base import BaseComponent, Component
from healthchain.pipeline.components.model import Model
from healthchain.pipeline.components.preprocessors import TextPreProcessor
from healthchain.pipeline.components.postprocessors import TextPostProcessor
from healthchain.pipeline.medicalcodingpipeline import MedicalCodingPipeline
Expand All @@ -10,7 +9,6 @@
"Pipeline",
"BaseComponent",
"Component",
"Model",
"TextPreProcessor",
"TextPostProcessor",
"MedicalCodingPipeline",
Expand Down
6 changes: 4 additions & 2 deletions healthchain/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from abc import ABC, abstractmethod
from inspect import signature
from typing import (
Any,
Callable,
Optional,
Type,
Expand Down Expand Up @@ -75,7 +76,7 @@ def __repr__(self) -> str:
return f"[{components_repr}]"

@classmethod
def load(cls, model_path: str) -> "BasePipeline":
def load(cls, model_name: str, **model_kwargs: Any) -> "BasePipeline":
"""
Load and configure a pipeline from a given model path.
Expand All @@ -84,12 +85,13 @@ def load(cls, model_path: str) -> "BasePipeline":
Args:
model_name (str): The name or path of the model used for configuring the pipeline.
**model_kwargs: Additional keyword arguments for the model.
Returns:
BasePipeline: A new instance of the pipeline, configured with the given model.
"""
pipeline = cls()
pipeline.configure_pipeline(model_path)
pipeline.configure_pipeline(model_name, **model_kwargs)

return pipeline

Expand Down
2 changes: 0 additions & 2 deletions healthchain/pipeline/components/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from .preprocessors import TextPreProcessor
from .postprocessors import TextPostProcessor
from .model import Model
from .base import BaseComponent, Component

__all__ = [
"TextPreProcessor",
"TextPostProcessor",
"Model",
"BaseComponent",
"Component",
]
21 changes: 0 additions & 21 deletions healthchain/pipeline/components/model.py

This file was deleted.

56 changes: 41 additions & 15 deletions healthchain/pipeline/medicalcodingpipeline.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,50 @@
from typing import Any
from healthchain.io.cdaconnector import CdaConnector
from healthchain.pipeline.base import BasePipeline
from healthchain.pipeline.components.preprocessors import TextPreProcessor
from healthchain.pipeline.components.postprocessors import TextPostProcessor
from healthchain.pipeline.components.model import Model
from healthchain.pipeline.modelrouter import ModelRouter


# TODO: Implement this pipeline in full
class MedicalCodingPipeline(BasePipeline):
def configure_pipeline(self, model_path: str) -> None:
"""
A pipeline for medical coding tasks using NLP models.
This pipeline is configured to process clinical documents using a medical NLP model
for tasks like named entity recognition and linking (NER+L). It uses CDA format
for input and output handling.
Examples:
>>> # Using with SpaCy/MedCAT
>>> pipeline = MedicalCodingPipeline.load("medcatlite")
>>>
>>> # Using with Hugging Face
>>> pipeline = MedicalCodingPipeline.load(
... "bert-base-uncased",
... task="ner"
... )
>>> results = pipeline(documents)
"""

def configure_pipeline(self, model_name: str, **model_kwargs: Any) -> None:
"""
Configure the pipeline with a medical NLP model and CDA connectors.
Args:
model_name: Name or path of the model to load
**model_kwargs: Additional configuration for the model
Raises:
ValueError: If no appropriate integration can be found for the model
ImportError: If required dependencies are not installed
"""
cda_connector = CdaConnector()
self.add_input(cda_connector)
# Add preprocessing component
self.add_node(TextPreProcessor(), stage="preprocessing")

# Add NER component
model = Model(
model_path
) # TODO: should converting the CcdData be a model concern?
self.add_node(model, stage="ner+l")
try:
model = ModelRouter.get_integration(model_name, **model_kwargs)
except (ValueError, ImportError) as e:
raise type(e)(
f"Failed to configure pipeline with model '{model_name}'. Error: {str(e)}"
)

# Add postprocessing component
self.add_node(TextPostProcessor(), stage="postprocessing")
self.add_input(cda_connector)
self.add_node(model, stage="ner+l")
self.add_output(cda_connector)
61 changes: 61 additions & 0 deletions healthchain/pipeline/modelrouter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from healthchain.pipeline.components.base import BaseComponent
from healthchain.pipeline.components.integrations import (
SpacyComponent,
HuggingFaceComponent,
)
import re
from typing import Any


class ModelRouter:
    """
    Select the integration component that matches a model name.

    This is an internal utility class used by pipelines to determine which
    integration to use for a given model. Routing is purely name-based: the
    model name is tested against the known SpaCy patterns first and then
    against the known Hugging Face patterns.
    """

    # Patterns are compiled once at class creation rather than on every call.
    # SpaCy models typically follow these patterns. scispacy names
    # (``en_core_sci_*``) are already matched by the general ``en_core_*``
    # pattern, so they need no separate entry.
    _SPACY_PATTERNS = tuple(
        re.compile(pattern)
        for pattern in (
            r"^en_core_.*$",  # standard spacy models (incl. scispacy)
            r"^.*/spacy/.*$",  # local spacy model paths
            r"^medcatlite$",  # medcat model
        )
    )

    # Hugging Face models typically include these patterns.
    _HF_PATTERNS = tuple(
        re.compile(pattern)
        for pattern in (
            r"^bert-.*$",
            r"^gpt-.*$",
            r"^t5-.*$",
            r"^distilbert-.*$",
            r".*/huggingface/.*$",
        )
    )

    @staticmethod
    def get_integration(model_name: str, **kwargs: Any) -> "BaseComponent":
        """
        Determine and return the appropriate integration component for the given model.

        Args:
            model_name: Name or path of the model to load.
            **kwargs: Additional arguments for the integration component.
                Currently only ``task`` is read, and only for Hugging Face
                models (defaults to ``"text-classification"``); any other
                keyword is ignored.

        Returns:
            An initialized integration component: a ``SpacyComponent`` when
            the name matches a SpaCy pattern, otherwise a
            ``HuggingFaceComponent`` when it matches a Hugging Face pattern.

        Raises:
            ValueError: If the model name matches none of the known patterns.
        """
        # SpaCy is checked first, so a name matching both families is
        # routed to SpaCy.
        if any(p.match(model_name) for p in ModelRouter._SPACY_PATTERNS):
            return SpacyComponent(model_name)

        if any(p.match(model_name) for p in ModelRouter._HF_PATTERNS):
            return HuggingFaceComponent(
                model=model_name,
                task=kwargs.get("task", "text-classification"),
            )

        raise ValueError(
            f"Could not determine appropriate integration for model: {model_name}"
        )
2 changes: 1 addition & 1 deletion tests/pipeline/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def configure_pipeline(self, model_path: str) -> None:

@pytest.fixture
def mock_model():
with patch("healthchain.pipeline.components.model.Model") as mock:
with patch("healthchain.pipeline.modelrouter.ModelRouter.get_integration") as mock:
model_instance = mock.return_value
model_instance.return_value = Document(
data="Processed note",
Expand Down
10 changes: 8 additions & 2 deletions tests/pipeline/prebuilt/test_medicalcoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
def test_coding_pipeline(mock_cda_connector, mock_model):
with patch(
"healthchain.pipeline.medicalcodingpipeline.CdaConnector", mock_cda_connector
), patch("healthchain.pipeline.medicalcodingpipeline.Model", mock_model):
), patch(
"healthchain.pipeline.medicalcodingpipeline.ModelRouter.get_integration",
mock_model,
):
pipeline = MedicalCodingPipeline.load("./path/to/model")

# Create a sample CdaRequest
Expand Down Expand Up @@ -47,7 +50,10 @@ def test_coding_pipeline(mock_cda_connector, mock_model):

def test_full_coding_pipeline_integration(mock_model, test_cda_request):
# Use mock model object for now
with patch("healthchain.pipeline.medicalcodingpipeline.Model", mock_model):
with patch(
"healthchain.pipeline.medicalcodingpipeline.ModelRouter.get_integration",
mock_model,
):
# this load method doesn't do anything yet
pipeline = MedicalCodingPipeline.load("./path/to/production/model")

Expand Down

0 comments on commit dd3858e

Please sign in to comment.