Skip to content

Commit dad2439

Browse files
damian3031hashhar
authored andcommitted
Fix logic for setting http_scheme
1 parent 84df329 commit dad2439

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

tests/unit/test_dbapi.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from unittest.mock import patch
1515

1616
import httpretty
17+
import pytest
1718
from httpretty import httprettified
1819
from requests import Session
1920

@@ -314,3 +315,26 @@ def test_description_is_none_when_cursor_is_not_executed():
314315
connection = Connection("sample_trino_cluster:443")
315316
with connection.cursor() as cursor:
316317
assert hasattr(cursor, 'description')
318+
319+
320+
@pytest.mark.parametrize(
321+
"host, port, http_scheme_input_argument, http_scheme_set",
322+
[
323+
# Infer from hostname
324+
("https://mytrinoserver.domain:9999", None, None, constants.HTTPS),
325+
("http://mytrinoserver.domain:9999", None, None, constants.HTTP),
326+
# Infer from port
327+
("mytrinoserver.domain", constants.DEFAULT_TLS_PORT, None, constants.HTTPS),
328+
("mytrinoserver.domain", constants.DEFAULT_PORT, None, constants.HTTP),
329+
# http_scheme parameter has higher precedence than port parameter
330+
("mytrinoserver.domain", constants.DEFAULT_TLS_PORT, constants.HTTP, constants.HTTP),
331+
("mytrinoserver.domain", constants.DEFAULT_PORT, constants.HTTPS, constants.HTTPS),
332+
# Set explicitly by http_scheme parameter
333+
("mytrinoserver.domain", None, constants.HTTPS, constants.HTTPS),
334+
# Default
335+
("mytrinoserver.domain", None, None, constants.HTTP),
336+
],
337+
)
338+
def test_setting_http_scheme(host, port, http_scheme_input_argument, http_scheme_set):
339+
connection = Connection(host, port, http_scheme=http_scheme_input_argument)
340+
assert connection.http_scheme == http_scheme_set

trino/dbapi.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def __init__(
150150
schema=constants.DEFAULT_SCHEMA,
151151
session_properties=None,
152152
http_headers=None,
153-
http_scheme=constants.HTTP,
153+
http_scheme=None,
154154
auth=constants.DEFAULT_AUTH,
155155
extra_credential=None,
156156
max_attempts=constants.DEFAULT_MAX_ATTEMPTS,
@@ -202,7 +202,18 @@ def __init__(
202202
else:
203203
self._http_session = http_session
204204
self.http_headers = http_headers
205-
self.http_scheme = http_scheme if not parsed_host.scheme else parsed_host.scheme
205+
206+
# Set http_scheme
207+
if parsed_host.scheme:
208+
self.http_scheme = parsed_host.scheme
209+
elif http_scheme:
210+
self.http_scheme = http_scheme
211+
elif port == constants.DEFAULT_TLS_PORT:
212+
self.http_scheme = constants.HTTPS
213+
elif port == constants.DEFAULT_PORT:
214+
self.http_scheme = constants.HTTP
215+
else:
216+
self.http_scheme = constants.HTTP
206217

207218
# Infer connection port: `hostname` takes precedence over explicit `port` argument
208219
# If none is given, use default based on HTTP protocol

0 commit comments

Comments
 (0)