Skip to content

Commit

Permalink
Merge pull request #31 from nmervaillie/sso-support
Browse files Browse the repository at this point in the history
Ability to initialize the store with an existing driver object
  • Loading branch information
alfredorubin96 authored Jul 9, 2024
2 parents 31a171e + fe49d5a commit f78e5c9
Show file tree
Hide file tree
Showing 12 changed files with 147 additions and 110 deletions.
3 changes: 2 additions & 1 deletion docs/modules/ROOT/pages/neo4jstore.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ This class is an implementation of the rdflib link:https://rdflib.readthedocs.io
== Constructor
|===
| Name | Type | Required | Default | Description
|config|Neo4jStoreConfig|True||Neo4jStoreConfig object that contains all the useful informations to initialize the store.
|config|Neo4jStoreConfig|True||Neo4jStoreConfig object that contains all the useful information to initialize the store.
|neo4j_driver|Driver|False|None|A pre-built Neo4j driver object used to connect to the database. You cannot specify both a driver and credentials in the Neo4jStoreConfig.
|===

== Functions
Expand Down
2 changes: 1 addition & 1 deletion docs/modules/ROOT/pages/neo4jstoreconfig.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ This object is used to configure the Neo4j Store to connect to your Neo4j Instan
== Constructor
|===
| Name | Type | Required | Values(Default) | Description
| auth_data | Dictionary | True | ("uri", "database", "user", "pwd") | A dictionary containing authentication data. The required keys are: ["uri", "database", "user", "pwd"].
| auth_data | Dictionary | Yes, unless a driver object is passed in the store init | ("uri", "database", "user", "pwd") | A dictionary containing authentication data. The required keys are: ["uri", "database", "user", "pwd"].
| batching | Boolean | False | boolean (True) | A boolean indicating whether batching is enabled.
| batch_size | Integer | False | (5000) | An integer representing the batch size. The batch size is the number of entities (nodes/relationships) to store in the database, not the number of triples.
| custom_mappings | List[Tuple[Str,Str,Str]] | False | Empty list | A list of tuples containing custom mappings for prefixes in the form (prefix, object_to_replace, new_object).
Expand Down
31 changes: 20 additions & 11 deletions rdflib_neo4j/Neo4jStore.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import Dict

from rdflib.store import Store
from neo4j import GraphDatabase
from neo4j import GraphDatabase, Driver
from neo4j import WRITE_ACCESS
import logging

from rdflib_neo4j.Neo4jTriple import Neo4jTriple
from rdflib_neo4j.config.Neo4jStoreConfig import Neo4jStoreConfig
from rdflib_neo4j.config.const import NEO4J_DRIVER_USER_AGENT_NAME
from rdflib_neo4j.config.utils import check_auth_data
from rdflib_neo4j.query_composers.NodeQueryComposer import NodeQueryComposer
from rdflib_neo4j.query_composers.RelationshipQueryComposer import RelationshipQueryComposer
from rdflib_neo4j.utils import handle_neo4j_driver_exception
Expand All @@ -17,11 +18,16 @@ class Neo4jStore(Store):

context_aware = True

def __init__(self, config: Neo4jStoreConfig):
def __init__(self, config: Neo4jStoreConfig, neo4j_driver: Driver = None):
self.__open = False
self.driver = None
self.driver = neo4j_driver
self.session = None
self.config = config
if not neo4j_driver:
check_auth_data(config.auth_data)
elif config.auth_data:
raise Exception("Either initialize the store with credentials or driver. You cannot do both.")

super(Neo4jStore, self).__init__(config.get_config_dict())

self.batching = config.batching
Expand Down Expand Up @@ -62,7 +68,6 @@ def close(self, commit_pending_transaction=True):
self.commit(commit_nodes=True)
self.commit(commit_rels=True)
self.session.close()
self.driver.close()
self.__set_open(False)
print(f"IMPORTED {self.total_triples} TRIPLES")
self.total_triples=0
Expand Down Expand Up @@ -147,6 +152,16 @@ def __set_open(self, val: bool):
self.__open = val
print(f"The store is now: {'Open' if self.__open else 'Closed'}")

def __get_driver(self) -> Driver:
if not self.driver:
auth_data = self.config.auth_data
self.driver = GraphDatabase.driver(
auth_data['uri'],
auth=(auth_data['user'], auth_data['pwd']),
database=auth_data.get('database', 'neo4j'),
user_agent=NEO4J_DRIVER_USER_AGENT_NAME
)
return self.driver

def __create_session(self):
"""
Expand All @@ -156,13 +171,7 @@ def __create_session(self):
"""
auth_data = self.config.auth_data
self.driver = GraphDatabase.driver(
auth_data['uri'],
auth=(auth_data['user'], auth_data['pwd']),
user_agent=NEO4J_DRIVER_USER_AGENT_NAME
)
self.session = self.driver.session(
database=auth_data.get('database', 'neo4j'),
self.session = self.__get_driver().session(
default_access_mode=WRITE_ACCESS
)

Expand Down
4 changes: 0 additions & 4 deletions rdflib_neo4j/config/Neo4jStoreConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,7 @@ def set_auth_data(self, auth):
Parameters:
- auth: A dictionary containing authentication data.
Raises:
- WrongAuthenticationException: If any of the required authentication fields is missing.
"""
check_auth_data(auth=auth)
self.auth_data = auth

def set_batching(self, val: bool):
Expand Down Expand Up @@ -225,5 +222,4 @@ def get_config_dict(self):
Raises:
- WrongAuthenticationException: If any of the required authentication fields is missing.
"""
check_auth_data(auth=self.auth_data)
return vars(self)
2 changes: 1 addition & 1 deletion test/integration/containers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from test.integration.utils import records_equal, read_file_n10s_and_rdflib
import pytest
from test.integration.fixtures import neo4j_container, neo4j_driver, graph_store, graph_store_batched, \
cleanup_databases
cleanup_databases, neo4j_connection_parameters


def test_import_person(neo4j_driver, graph_store):
Expand Down
16 changes: 8 additions & 8 deletions test/integration/custom_mappings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,21 @@
from rdflib_neo4j.Neo4jStore import Neo4jStore
from rdflib_neo4j.config.Neo4jStoreConfig import Neo4jStoreConfig
from test.integration.constants import LOCAL
from test.integration.utils import records_equal, read_file_n10s_and_rdflib, get_credentials
from test.integration.utils import records_equal, read_file_n10s_and_rdflib
from rdflib_neo4j.config.const import HANDLE_VOCAB_URI_STRATEGY
import os
from dotenv import load_dotenv
from test.integration.fixtures import neo4j_container, neo4j_driver, graph_store, graph_store_batched, \
cleanup_databases
cleanup_databases, neo4j_connection_parameters


def test_custom_mapping_match(neo4j_container, neo4j_driver):
def test_custom_mapping_match(neo4j_driver, neo4j_connection_parameters):
"""
If we define a custom mapping and the strategy is HANDLE_VOCAB_URI_STRATEGY.MAP, it should match it and use the mapping
if the predicate satisfies the mapping.
"""

auth_data = get_credentials(LOCAL, neo4j_container)
auth_data = neo4j_connection_parameters
# Define your prefixes
prefixes = {
'neo4voc': Namespace('http://neo4j.org/vocab/sw#')
Expand Down Expand Up @@ -56,7 +56,7 @@ def test_custom_mapping_match(neo4j_container, neo4j_driver):
assert records_equal(rels[i], rels_from_rdflib[i], rels=True)


def test_custom_mapping_no_match(neo4j_container, neo4j_driver):
def test_custom_mapping_no_match(neo4j_driver, neo4j_connection_parameters):
"""
If we define a custom mapping and the strategy is HANDLE_VOCAB_URI_STRATEGY.MAP, it shouldn't apply the mapping if the
predicate doesn't satisfy the mapping and use IGNORE as a strategy.
Expand All @@ -66,7 +66,7 @@ def test_custom_mapping_no_match(neo4j_container, neo4j_driver):
if the predicate satisfies the mapping.
"""

auth_data = get_credentials(LOCAL, neo4j_container)
auth_data = neo4j_connection_parameters

# Define your prefixes
prefixes = {
Expand Down Expand Up @@ -106,12 +106,12 @@ def test_custom_mapping_no_match(neo4j_container, neo4j_driver):
assert records_equal(rels[i], rels_from_rdflib[i], rels=True)


def test_custom_mapping_map_strategy_zero_custom_mappings(neo4j_container, neo4j_driver):
def test_custom_mapping_map_strategy_zero_custom_mappings(neo4j_driver, neo4j_connection_parameters):
"""
If we don't define custom mapping and the strategy is HANDLE_VOCAB_URI_STRATEGY.MAP, it shouldn't apply the mapping on anything and
just use IGNORE mode.
"""
auth_data = get_credentials(LOCAL, neo4j_container)
auth_data = neo4j_connection_parameters

# Define your prefixes
prefixes = {
Expand Down
41 changes: 36 additions & 5 deletions test/integration/fixtures.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import pytest
from neo4j import GraphDatabase
from rdflib import Graph
from testcontainers.neo4j import Neo4jContainer

from rdflib_neo4j import HANDLE_VOCAB_URI_STRATEGY, Neo4jStoreConfig, Neo4jStore
from test.integration.constants import LOCAL, N10S_CONSTRAINT_QUERY, RDFLIB_DB
from test.integration.utils import create_graph_store
import os


Expand Down Expand Up @@ -48,13 +49,43 @@ def neo4j_driver(neo4j_container):


@pytest.fixture
def graph_store(neo4j_container, neo4j_driver):
return create_graph_store(neo4j_container)
def graph_store(neo4j_connection_parameters):
return config_graph_store(neo4j_connection_parameters)


@pytest.fixture
def graph_store_batched(neo4j_container, neo4j_driver):
return create_graph_store(neo4j_container, batching=True)
def graph_store_batched(neo4j_connection_parameters):
return config_graph_store(neo4j_connection_parameters, True)


def config_graph_store(auth_data, batching=False):

config = Neo4jStoreConfig(auth_data=auth_data,
custom_prefixes={},
custom_mappings=[],
multival_props_names=[],
handle_vocab_uri_strategy=HANDLE_VOCAB_URI_STRATEGY.IGNORE,
batching=batching)

g = Graph(store=Neo4jStore(config=config))
return g


@pytest.fixture
def neo4j_connection_parameters(neo4j_container):
if LOCAL:
auth_data = {
'uri': os.getenv("NEO4J_URI_LOCAL"),
'database': RDFLIB_DB,
'user': os.getenv("NEO4J_USER_LOCAL"),
'pwd': os.getenv("NEO4J_PWD_LOCAL")
}
else:
auth_data = {'uri': neo4j_container.get_connection_url(),
'database': RDFLIB_DB,
'user': "neo4j",
'pwd': Neo4jContainer.NEO4J_ADMIN_PASSWORD}
return auth_data


@pytest.fixture(autouse=True)
Expand Down
24 changes: 12 additions & 12 deletions test/integration/handle_vocab_uri_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@
from rdflib_neo4j.config.Neo4jStoreConfig import Neo4jStoreConfig
from rdflib_neo4j.config.const import ShortenStrictException, HANDLE_VOCAB_URI_STRATEGY
from test.integration.constants import LOCAL
from test.integration.utils import records_equal, read_file_n10s_and_rdflib, get_credentials
from test.integration.utils import records_equal, read_file_n10s_and_rdflib
import pytest
from test.integration.fixtures import neo4j_container, neo4j_driver, graph_store, graph_store_batched, \
from test.integration.fixtures import neo4j_container, neo4j_connection_parameters, neo4j_driver, graph_store, graph_store_batched, \
cleanup_databases


def test_shorten_all_prefixes_defined(neo4j_container, neo4j_driver):
def test_shorten_all_prefixes_defined(neo4j_driver, neo4j_connection_parameters):
"""
If we use the strategy HANDLE_VOCAB_URI_STRATEGY.SHORTEN and we provide all the required namespaces,
it should load all the data without raising an error for a missing prefix
"""
auth_data = get_credentials(LOCAL, neo4j_container)
auth_data = neo4j_connection_parameters

# Define your prefixes
prefixes = {
Expand Down Expand Up @@ -61,8 +61,8 @@ def test_shorten_all_prefixes_defined(neo4j_container, neo4j_driver):
assert records_equal(rels[i], rels_from_rdflib[i], rels=True)


def test_shorten_missing_prefix(neo4j_container, neo4j_driver):
auth_data = get_credentials(LOCAL, neo4j_container)
def test_shorten_missing_prefix(neo4j_driver, neo4j_connection_parameters):
auth_data = neo4j_connection_parameters

# Define your prefixes
prefixes = {
Expand Down Expand Up @@ -90,8 +90,8 @@ def test_shorten_missing_prefix(neo4j_container, neo4j_driver):
assert True


def test_keep_strategy(neo4j_container, neo4j_driver):
auth_data = get_credentials(LOCAL, neo4j_container)
def test_keep_strategy(neo4j_driver, neo4j_connection_parameters):
auth_data = neo4j_connection_parameters

config = Neo4jStoreConfig(auth_data=auth_data,
handle_vocab_uri_strategy=HANDLE_VOCAB_URI_STRATEGY.KEEP,
Expand All @@ -111,8 +111,8 @@ def test_keep_strategy(neo4j_container, neo4j_driver):
assert records_equal(rels[i], rels_from_rdflib[i], rels=True)


def test_ignore_strategy(neo4j_container, neo4j_driver):
auth_data = get_credentials(LOCAL, neo4j_container)
def test_ignore_strategy(neo4j_driver, neo4j_connection_parameters):
auth_data = neo4j_connection_parameters

config = Neo4jStoreConfig(auth_data=auth_data,
handle_vocab_uri_strategy=HANDLE_VOCAB_URI_STRATEGY.IGNORE,
Expand All @@ -132,8 +132,8 @@ def test_ignore_strategy(neo4j_container, neo4j_driver):
assert records_equal(rels[i], rels_from_rdflib[i], rels=True)


def test_ignore_strategy_on_json_ld_file(neo4j_container, neo4j_driver):
auth_data = get_credentials(LOCAL, neo4j_container)
def test_ignore_strategy_on_json_ld_file(neo4j_driver, neo4j_connection_parameters):
auth_data = neo4j_connection_parameters

# Define your prefixes
prefixes = {
Expand Down
20 changes: 10 additions & 10 deletions test/integration/multival_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@
from rdflib import Graph, Namespace
from rdflib_neo4j.Neo4jStore import Neo4jStore
from rdflib_neo4j.config.Neo4jStoreConfig import Neo4jStoreConfig
from test.integration.utils import records_equal, read_file_n10s_and_rdflib, create_graph_store, get_credentials
from test.integration.utils import records_equal, read_file_n10s_and_rdflib
from rdflib_neo4j.config.const import HANDLE_VOCAB_URI_STRATEGY, HANDLE_MULTIVAL_STRATEGY
import pytest
from test.integration.fixtures import neo4j_container, neo4j_driver, graph_store, graph_store_batched, \
cleanup_databases
cleanup_databases, neo4j_connection_parameters


def test_read_file_multival_with_strategy_no_predicates(neo4j_container, neo4j_driver):
def test_read_file_multival_with_strategy_no_predicates(neo4j_driver, neo4j_connection_parameters):
"""Compare data imported with n10s procs and n10s + rdflib in single add mode for multivalues"""

auth_data = get_credentials(LOCAL, neo4j_container)
auth_data = neo4j_connection_parameters

# Define your prefixes
prefixes = {}
Expand Down Expand Up @@ -40,9 +40,9 @@ def test_read_file_multival_with_strategy_no_predicates(neo4j_container, neo4j_d
assert records_equal(records[i], records_from_rdf_lib[i])


def test_read_file_multival_with_strategy_and_predicates(neo4j_container, neo4j_driver):
def test_read_file_multival_with_strategy_and_predicates(neo4j_driver, neo4j_connection_parameters):
"""Compare data imported with n10s procs and n10s + rdflib in single add mode for multivalues"""
auth_data = get_credentials(LOCAL, neo4j_container)
auth_data = neo4j_connection_parameters

# Define your prefixes
prefixes = {
Expand Down Expand Up @@ -72,9 +72,9 @@ def test_read_file_multival_with_strategy_and_predicates(neo4j_container, neo4j_
assert records_equal(records[i], records_from_rdf_lib[i])


def test_read_file_multival_with_no_strategy_and_predicates(neo4j_container, neo4j_driver):
def test_read_file_multival_with_no_strategy_and_predicates(neo4j_driver, neo4j_connection_parameters):
"""Compare data imported with n10s procs and n10s + rdflib in single add mode for multivalues"""
auth_data = get_credentials(LOCAL, neo4j_container)
auth_data = neo4j_connection_parameters

# Define your prefixes
prefixes = {
Expand All @@ -101,9 +101,9 @@ def test_read_file_multival_with_no_strategy_and_predicates(neo4j_container, neo
for i in range(len(records)):
assert records_equal(records[i], records_from_rdf_lib[i])

def test_read_file_multival_array_as_set_behavior(neo4j_container, neo4j_driver):
def test_read_file_multival_array_as_set_behavior(neo4j_driver, neo4j_connection_parameters):
"""When importing the data, if a triple will add the same value to a multivalued property it won't be added"""
auth_data = get_credentials(LOCAL, neo4j_container)
auth_data = neo4j_connection_parameters

prefixes = {'music': Namespace('neo4j://graph.schema#')}

Expand Down
2 changes: 1 addition & 1 deletion test/integration/single_triple_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from test.integration.constants import GET_DATA_QUERY, RDFLIB_DB
import pytest
from test.integration.fixtures import neo4j_container, neo4j_driver, graph_store, graph_store_batched, \
cleanup_databases
cleanup_databases, neo4j_connection_parameters


def test_import_type_as_label(neo4j_driver, graph_store):
Expand Down
Loading

0 comments on commit f78e5c9

Please sign in to comment.