Skip to content

Commit

Permalink
Add test for custom CA loading
Browse files Browse the repository at this point in the history
  • Loading branch information
strawgate committed Jan 9, 2025
1 parent 50c62d8 commit 79eacc0
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 20 deletions.
25 changes: 25 additions & 0 deletions tests/snapshots/test_es_gateway.ambr
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# serializer version: 1
# name: Test_Init.Test_ElasticsearchStateful.test_es8_async_init_with_tls_custom_ca
dict({
'issuer': tuple(
tuple(
tuple(
'commonName',
'Elasticsearch security auto-configuration HTTP CA',
),
),
),
'notAfter': 'Dec 18 17:55:54 2027 GMT',
'notBefore': 'Dec 18 17:55:54 2024 GMT',
'serialNumber': '25813FA4F725F5566FCF014C0B8B0973E710DF90',
'subject': tuple(
tuple(
tuple(
'commonName',
'Elasticsearch security auto-configuration HTTP CA',
),
),
),
'version': 3,
})
# ---
78 changes: 58 additions & 20 deletions tests/test_es_gateway.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for the Elasticsearch Gateway."""
# noqa: F401 # pylint: disable=redefined-outer-name

import os
import ssl
from typing import Any
from unittest.mock import AsyncMock, MagicMock
Expand Down Expand Up @@ -198,11 +199,51 @@ async def test_es8_async_init_with_tls(self, es_mock_builder) -> None:

assert gateway._client is not None
node: BaseNode = gateway._client._transport.node_pool.get()
assert node is not None
assert hasattr(node, "_ssl_context")
if hasattr(node, "_ssl_context"):
assert node._ssl_context.check_hostname is True
assert node._ssl_context.verify_mode == ssl.CERT_REQUIRED
ssl_context = node._ssl_context # type: ignore[reportAttributeAccessIssue]

assert ssl_context.check_hostname is True
assert ssl_context.verify_mode == ssl.CERT_REQUIRED

await gateway.stop()

async def test_es8_async_init_with_tls_custom_ca(self, es_mock_builder, snapshot) -> None:
"""Test initializing a gateway with TLS and custom ca."""

# cert is located in "certs/http_ca.crt" relative to this file, get the absolute path

current_directory = os.path.dirname(os.path.abspath(__file__))

gateway = Elasticsearch8Gateway(
gateway_settings=Gateway8Settings(
url=const.TEST_CONFIG_ENTRY_DATA_URL,
verify_certs=True,
verify_hostname=True,
ca_certs=f"{current_directory}/certs/http_ca.crt",
)
)

es_mock_builder.as_elasticsearch_8_17(with_security=True).with_correct_permissions()

assert await gateway.async_init() is None

assert gateway._client is not None
node: BaseNode = gateway._client._transport.node_pool.get()
ssl_context = node._ssl_context # type: ignore[reportAttributeAccessIssue]

assert ssl_context.check_hostname is True
assert ssl_context.verify_mode == ssl.CERT_REQUIRED

ca_certs = ssl_context.get_ca_certs()

added_cert: dict | None = None

for cert in ca_certs:
if cert["serialNumber"] == "25813FA4F725F5566FCF014C0B8B0973E710DF90":
added_cert = cert
break

assert added_cert is not None
assert added_cert == snapshot

await gateway.stop()

Expand All @@ -222,11 +263,10 @@ async def test_es8_async_init_with_tls_no_hostname(self, es_mock_builder) -> Non

assert gateway._client is not None
node: BaseNode = gateway._client._transport.node_pool.get()
assert node is not None
assert hasattr(node, "_ssl_context")
if hasattr(node, "_ssl_context"):
assert node._ssl_context.check_hostname is False
assert node._ssl_context.verify_mode == ssl.CERT_REQUIRED
ssl_context = node._ssl_context # type: ignore[reportAttributeAccessIssue]

assert ssl_context.check_hostname is False
assert ssl_context.verify_mode == ssl.CERT_REQUIRED

await gateway.stop()

Expand All @@ -246,11 +286,10 @@ async def test_es8_async_init_without_tls(self, es_mock_builder, snapshot) -> No

assert gateway._client is not None
node: BaseNode = gateway._client._transport.node_pool.get()
assert node is not None
assert hasattr(node, "_ssl_context")
if hasattr(node, "_ssl_context"):
assert node._ssl_context.check_hostname is False
assert node._ssl_context.verify_mode == ssl.CERT_NONE
ssl_context = node._ssl_context # type: ignore[reportAttributeAccessIssue]

assert ssl_context.check_hostname is False
assert ssl_context.verify_mode == ssl.CERT_NONE

await gateway.stop()

Expand All @@ -269,11 +308,10 @@ async def test_es8_async_init_without_tls(self, es_mock_builder, snapshot) -> No

assert gateway._client is not None
node: BaseNode = gateway._client._transport.node_pool.get()
assert node is not None
assert hasattr(node, "_ssl_context")
if hasattr(node, "_ssl_context"):
assert node._ssl_context.check_hostname is False
assert node._ssl_context.verify_mode == ssl.CERT_NONE
ssl_context = node._ssl_context # type: ignore[reportAttributeAccessIssue]

assert ssl_context.check_hostname is False
assert ssl_context.verify_mode == ssl.CERT_NONE

await gateway.stop()

Expand Down

0 comments on commit 79eacc0

Please sign in to comment.