Skip to content

Commit

Permalink
fix: Restore spaCy model selection and config file functionality
Browse files Browse the repository at this point in the history
This commit addresses regressions from previous changes and improves configuration management:

- Reintroduce ability to change active spaCy model using `topos set --spacy trf`
- Move config.yaml to user's system config directory
- Add toposSetupHook in flake.nix to initialize config file if not present
- Update utilities to use TOPOS_CONFIG_PATH environment variable
- Modify OntologicalFeatureDetection and token_classifiers to load spaCy model from config
- Remove direct spaCy model download from spacy_loader.py
- Update get_config_path utility function

These changes restore the flexibility of spaCy model selection and improve
the overall configuration management of the application.
  • Loading branch information
G-structure committed Oct 16, 2024
1 parent cad580b commit c0faf55
Show file tree
Hide file tree
Showing 9 changed files with 179 additions and 25 deletions.
1 change: 1 addition & 0 deletions config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
active_spacy_model: en_core_web_trf
21 changes: 21 additions & 0 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,25 @@
);
};

configFile = pkgs.copyPathToStore ./config.yaml;
yq = pkgs.yq-go;

# Shell snippet that is interpolated into the shellHook of the shells below.
# It exports TOPOS_CONFIG_PATH as "$HOME/.config/topos/config.yaml" and makes
# sure the parent directory exists.
# Note: the settings from the repo's config file (configFile) are only written
# when no file exists at that path yet; an existing user config is left as is.
# A new file starts with a "# Topos Configuration" header line, followed by the
# output of yq run on configFile, appended line by line.
toposSetupHook = ''
export TOPOS_CONFIG_PATH="$HOME/.config/topos/config.yaml"
mkdir -p "$(dirname "$TOPOS_CONFIG_PATH")"
if [ ! -f "$TOPOS_CONFIG_PATH" ]; then
echo "Creating new config file at $TOPOS_CONFIG_PATH"
echo "# Topos Configuration" > "$TOPOS_CONFIG_PATH"
${yq}/bin/yq eval ${configFile} | while IFS= read -r line; do
echo "$line" >> "$TOPOS_CONFIG_PATH"
done
echo "Config file created at $TOPOS_CONFIG_PATH"
else
echo "Config file already exists at $TOPOS_CONFIG_PATH"
fi
'';
in
{
packages = {
Expand All @@ -67,6 +86,7 @@

shellHook = ''
export PATH="${pkgs.myapp}/bin:$PATH"
${toposSetupHook}
'';
};

Expand All @@ -80,6 +100,7 @@

shellHook = ''
export PATH="${pkgs.myapp}/bin:$PATH"
${toposSetupHook}
topos run
'';
};
Expand Down
123 changes: 122 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ pystray = "0.19.5"
supabase = "^2.6.0"
psycopg2-binary = "^2.9.9"
en-core-web-sm = {url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl"}
en-core-web-lg = {url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_lg-3.8.0/en_core_web_lg-3.8.0-py3-none-any.whl"}
en-core-web-md = {url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.8.0/en_core_web_md-3.8.0-py3-none-any.whl"}
en-core-web-trf = {url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_trf-3.8.0/en_core_web_trf-3.8.0-py3-none-any.whl"}
[tool.poetry.group.dev.dependencies]
pytest = "^7.4.3"
pytest-asyncio = "^0.23.2"
Expand Down
8 changes: 4 additions & 4 deletions topos/FC/ontological_feature_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from datetime import datetime

from topos.services.database.app_state import AppState
from topos.utilities.utils import get_root_directory
from topos.utilities.utils import get_config_path
import os
import yaml

Expand All @@ -37,11 +37,11 @@ def __init__(self, neo4j_uri, neo4j_user, neo4j_password, neo4j_database_name, u
self.tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
self.model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
# The config file location comes from get_config_path() (TOPOS_CONFIG_PATH environment variable)
config_path = os.path.join(get_root_directory(), 'config.yaml')
config_path = get_config_path()

with open(config_path, 'r') as file:
settings = yaml.safe_load(file)

spacy_model_name = settings.get('active_spacy_model')

# Load SpaCy models
Expand Down Expand Up @@ -639,4 +639,4 @@ def get_connected_nodes(self, node_id, edges):
# paragraph = (
# "John, a software engineer from New York, bought a new laptop from Amazon on Saturday. "
# "He later met with his friend Alice, who is a data scientist at Google, for coffee at Starbucks. "
# "They discussed a variety of topics including the recent advancements in arti
# "They discussed a variety of topics including the recent advancements in arti
1 change: 0 additions & 1 deletion topos/config.yaml

This file was deleted.

16 changes: 5 additions & 11 deletions topos/downloaders/spacy_loader.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import subprocess
import yaml
import os
from ..utilities.utils import get_python_command, get_root_directory

from ..utilities.utils import get_config_path

def download_spacy_model(model_selection):
if model_selection == 'small':
Expand All @@ -15,17 +13,13 @@ def download_spacy_model(model_selection):
model_name = "en_core_web_trf"
else: #default
model_name = "en_core_web_sm"

python_command = get_python_command()


# Define the path to the config.yaml file
config_path = os.path.join(get_root_directory(), 'config.yaml')
config_path = get_config_path()
try:
subprocess.run([python_command, '-m', 'spacy', 'download', model_name], check=True)
# Write updated settings to YAML file
with open(config_path, 'w') as file:
yaml.dump({'active_spacy_model': model_name}, file)
print(f"Successfully downloaded '{model_name}' spaCy model.")
print(f"'{model_name}' set as active model.")
except subprocess.CalledProcessError as e:
print(f"Error downloading '{model_name}' spaCy model: {e}")
except Exception as e:
print(f"An error occurred setting config.yaml: {e}")
17 changes: 14 additions & 3 deletions topos/services/basic_analytics/token_classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,18 @@

import spacy
from spacy.tokens import Token
import en_core_web_sm
import yaml

from topos.utilities.utils import get_config_path

# The config file location comes from get_config_path() (TOPOS_CONFIG_PATH environment variable)
config_path = get_config_path()

with open(config_path, 'r') as file:
settings = yaml.safe_load(file)

# Load the spacy model setting
model_name = settings.get('active_spacy_model')

def get_token_sent(token):
'''
Expand All @@ -12,8 +23,8 @@ def get_token_sent(token):
return token_span.sent

# Now you can use `model_name` in your code
print(f"[ mem-loader :: Using spaCy model: en_core_web_sm ]")
nlp = en_core_web_sm.load()
print(f"[ mem-loader :: Using spaCy model: {model_name} ]")
nlp = spacy.load(model_name)
Token.set_extension('sent', getter=get_token_sent, force = True)

def get_entity_dict(doc):
Expand Down
14 changes: 9 additions & 5 deletions topos/utilities/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import os
import shutil


def get_python_command():
if shutil.which("python"):
return "python"
Expand All @@ -12,21 +11,26 @@ def get_python_command():
else:
raise EnvironmentError("No Python interpreter found")

def get_config_path():
    """Return the config file path named by ``TOPOS_CONFIG_PATH``.

    The environment variable is read on every call. Surrounding whitespace
    is ignored and a leading ``~`` is expanded to the user's home directory,
    so a value that never went through shell expansion still resolves.
    The function does not check that the file exists.

    Returns:
        str: The configured path with ``~`` expanded.

    Raises:
        EnvironmentError: If ``TOPOS_CONFIG_PATH`` is unset or blank.
    """
    config_path = os.getenv('TOPOS_CONFIG_PATH', '').strip()
    if not config_path:
        raise EnvironmentError("TOPOS_CONFIG_PATH environment variable is not set")
    # Without this, "~/.config/..." would be opened relative to the CWD.
    return os.path.expanduser(config_path)

def get_root_directory():
    """Return the innermost ancestor directory of this file named ``topos``.

    The search starts at the directory containing this file and walks up
    one path component at a time. Only a directory whose name is exactly
    ``topos`` matches; names that merely contain the text (for example
    ``topos_data`` or ``mytopos``) are skipped.

    Returns:
        str: Absolute path of the matching ``topos`` directory.

    Raises:
        ValueError: If no ancestor directory is named ``topos``.
    """
    # Get the current file's directory
    candidate = os.path.dirname(os.path.abspath(__file__))

    while True:
        if os.path.basename(candidate) == "topos":
            return candidate
        parent = os.path.dirname(candidate)
        # dirname() of a filesystem root returns the root itself: stop there.
        if parent == candidate:
            raise ValueError("The 'topos' directory was not found in the path.")
        candidate = parent

def parse_json(data):
    """Decode the JSON document ``data`` and return the resulting object."""
    from json import loads
    return loads(data)
Expand Down Expand Up @@ -57,4 +61,4 @@ def generate_hex_code(n_digits):
return ''.join(random.choice('0123456789ABCDEF') for _ in range(n_digits))

def generate_deci_code(n_digits):
return ''.join(random.choice('0123456789') for _ in range(n_digits))
return ''.join(random.choice('0123456789') for _ in range(n_digits))

0 comments on commit c0faf55

Please sign in to comment.