diff --git a/tests/snapshots/test_es_gateway.ambr b/tests/snapshots/test_es_gateway.ambr new file mode 100644 index 0000000..8915fd2 --- /dev/null +++ b/tests/snapshots/test_es_gateway.ambr @@ -0,0 +1,25 @@ +# serializer version: 1 +# name: Test_Init.Test_ElasticsearchStateful.test_es8_async_init_with_tls_custom_ca + dict({ + 'issuer': tuple( + tuple( + tuple( + 'commonName', + 'Elasticsearch security auto-configuration HTTP CA', + ), + ), + ), + 'notAfter': 'Dec 18 17:55:54 2027 GMT', + 'notBefore': 'Dec 18 17:55:54 2024 GMT', + 'serialNumber': '25813FA4F725F5566FCF014C0B8B0973E710DF90', + 'subject': tuple( + tuple( + tuple( + 'commonName', + 'Elasticsearch security auto-configuration HTTP CA', + ), + ), + ), + 'version': 3, + }) +# --- diff --git a/tests/test_es_gateway.py b/tests/test_es_gateway.py index bc32f2a..2f7c806 100644 --- a/tests/test_es_gateway.py +++ b/tests/test_es_gateway.py @@ -1,6 +1,7 @@ """Tests for the Elasticsearch Gateway.""" # noqa: F401 # pylint: disable=redefined-outer-name +import os import ssl from typing import Any from unittest.mock import AsyncMock, MagicMock @@ -198,11 +199,51 @@ async def test_es8_async_init_with_tls(self, es_mock_builder) -> None: assert gateway._client is not None node: BaseNode = gateway._client._transport.node_pool.get() - assert node is not None - assert hasattr(node, "_ssl_context") - if hasattr(node, "_ssl_context"): - assert node._ssl_context.check_hostname is True - assert node._ssl_context.verify_mode == ssl.CERT_REQUIRED + ssl_context = node._ssl_context # type: ignore[reportAttributeAccessIssue] + + assert ssl_context.check_hostname is True + assert ssl_context.verify_mode == ssl.CERT_REQUIRED + + await gateway.stop() + + async def test_es8_async_init_with_tls_custom_ca(self, es_mock_builder, snapshot) -> None: + """Test initializing a gateway with TLS and custom ca.""" + + # cert is located in "certs/http_ca.crt" relative to this file, get the absolute path + + current_directory = os.path.dirname(os.path.abspath(__file__)) + + gateway = Elasticsearch8Gateway( + gateway_settings=Gateway8Settings( + url=const.TEST_CONFIG_ENTRY_DATA_URL, + verify_certs=True, + verify_hostname=True, + ca_certs=f"{current_directory}/certs/http_ca.crt", + ) + ) + + es_mock_builder.as_elasticsearch_8_17(with_security=True).with_correct_permissions() + + assert await gateway.async_init() is None + + assert gateway._client is not None + node: BaseNode = gateway._client._transport.node_pool.get() + ssl_context = node._ssl_context # type: ignore[reportAttributeAccessIssue] + + assert ssl_context.check_hostname is True + assert ssl_context.verify_mode == ssl.CERT_REQUIRED + + ca_certs = ssl_context.get_ca_certs() + + added_cert: dict | None = None + + for cert in ca_certs: + if cert["serialNumber"] == "25813FA4F725F5566FCF014C0B8B0973E710DF90": + added_cert = cert + break + + assert added_cert is not None + assert added_cert == snapshot await gateway.stop() @@ -222,11 +263,10 @@ async def test_es8_async_init_with_tls_no_hostname(self, es_mock_builder) -> Non assert gateway._client is not None node: BaseNode = gateway._client._transport.node_pool.get() - assert node is not None - assert hasattr(node, "_ssl_context") - if hasattr(node, "_ssl_context"): - assert node._ssl_context.check_hostname is False - assert node._ssl_context.verify_mode == ssl.CERT_REQUIRED + ssl_context = node._ssl_context # type: ignore[reportAttributeAccessIssue] + + assert ssl_context.check_hostname is False + assert ssl_context.verify_mode == ssl.CERT_REQUIRED await gateway.stop() @@ -246,11 +286,10 @@ async def test_es8_async_init_without_tls(self, es_mock_builder, snapshot) -> No assert gateway._client is not None node: BaseNode = gateway._client._transport.node_pool.get() - assert node is not None - assert hasattr(node, "_ssl_context") - if hasattr(node, "_ssl_context"): - assert node._ssl_context.check_hostname is False - assert node._ssl_context.verify_mode == ssl.CERT_NONE + ssl_context = node._ssl_context # type: ignore[reportAttributeAccessIssue] + + assert ssl_context.check_hostname is False + assert ssl_context.verify_mode == ssl.CERT_NONE await gateway.stop() @@ -269,11 +308,10 @@ async def test_es8_async_init_without_tls(self, es_mock_builder, snapshot) -> No assert gateway._client is not None node: BaseNode = gateway._client._transport.node_pool.get() - assert node is not None - assert hasattr(node, "_ssl_context") - if hasattr(node, "_ssl_context"): - assert node._ssl_context.check_hostname is False - assert node._ssl_context.verify_mode == ssl.CERT_NONE + ssl_context = node._ssl_context # type: ignore[reportAttributeAccessIssue] + + assert ssl_context.check_hostname is False + assert ssl_context.verify_mode == ssl.CERT_NONE await gateway.stop()