diff --git a/BREEZE.rst b/BREEZE.rst index 7cd26a03c2405..ce5f6e5d4cb8f 100644 --- a/BREEZE.rst +++ b/BREEZE.rst @@ -2444,7 +2444,7 @@ This is the current syntax for `./breeze <./breeze>`_: start all integrations. Selected integrations are not saved for future execution. One of: - cassandra kerberos mongo openldap pinot presto rabbitmq redis statsd all + cassandra kerberos mongo openldap pinot rabbitmq redis statsd trino all --init-script INIT_SCRIPT_FILE Initialization script name - Sourced from files/airflow-breeze-config. Default value diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 3b3a77190fae3..1d5f151f44b90 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -594,7 +594,7 @@ github_enterprise, google, google_auth, grpc, hashicorp, hdfs, hive, http, imap, jira, kerberos, kubernetes, ldap, microsoft.azure, microsoft.mssql, microsoft.winrm, mongo, mssql, mysql, neo4j, odbc, openfaas, opsgenie, oracle, pagerduty, papermill, password, pinot, plexus, postgres, presto, qds, qubole, rabbitmq, redis, s3, salesforce, samba, segment, sendgrid, sentry, -sftp, singularity, slack, snowflake, spark, sqlite, ssh, statsd, tableau, telegram, vertica, +sftp, singularity, slack, snowflake, spark, sqlite, ssh, statsd, tableau, telegram, trino, vertica, virtualenv, webhdfs, winrm, yandex, zendesk .. END EXTRAS HERE @@ -661,11 +661,11 @@ apache.hive amazon,microsoft.mssql,mysql,presto,samba,vertica apache.livy http dingding http discord http -google amazon,apache.beam,apache.cassandra,cncf.kubernetes,facebook,microsoft.azure,microsoft.mssql,mysql,oracle,postgres,presto,salesforce,sftp,ssh +google amazon,apache.beam,apache.cassandra,cncf.kubernetes,facebook,microsoft.azure,microsoft.mssql,mysql,oracle,postgres,presto,salesforce,sftp,ssh,trino hashicorp google microsoft.azure google,oracle microsoft.mssql odbc -mysql amazon,presto,vertica +mysql amazon,presto,trino,vertica opsgenie http postgres amazon salesforce tableau @@ -756,7 +756,7 @@ providers. not only "green path" * Integration tests where 'local' integration with a component is possible (for example tests with - MySQL/Postgres DB/Presto/Kerberos all have integration tests which run with real, dockerised components + MySQL/Postgres DB/Trino/Kerberos all have integration tests which run with real, dockerized components * System Tests which provide end-to-end testing, usually testing together several operators, sensors, transfers connecting to a real external system diff --git a/IMAGES.rst b/IMAGES.rst index b5a309627916f..b206a4816a827 100644 --- a/IMAGES.rst +++ b/IMAGES.rst @@ -116,7 +116,7 @@ parameter to Breeze: .. code-block:: bash - ./breeze build-image --python 3.7 --additional-extras=presto \ + ./breeze build-image --python 3.7 --additional-extras=trino \ --production-image --install-airflow-version=2.0.0 @@ -163,7 +163,7 @@ You can also skip installing airflow and install it from locally provided files .. code-block:: bash - ./breeze build-image --python 3.7 --additional-extras=presto \ + ./breeze build-image --python 3.7 --additional-extras=trino \ --production-image --disable-pypi-when-building --install-from-local-files-when-building In this case you airflow and all packages (.whl files) should be placed in ``docker-context-files`` folder. 
diff --git a/INSTALL b/INSTALL index eeab5583f04af..46d15f62aa87b 100644 --- a/INSTALL +++ b/INSTALL @@ -106,7 +106,7 @@ github_enterprise, google, google_auth, grpc, hashicorp, hdfs, hive, http, imap, jira, kerberos, kubernetes, ldap, microsoft.azure, microsoft.mssql, microsoft.winrm, mongo, mssql, mysql, neo4j, odbc, openfaas, opsgenie, oracle, pagerduty, papermill, password, pinot, plexus, postgres, presto, qds, qubole, rabbitmq, redis, s3, salesforce, samba, segment, sendgrid, sentry, -sftp, singularity, slack, snowflake, spark, sqlite, ssh, statsd, tableau, telegram, vertica, +sftp, singularity, slack, snowflake, spark, sqlite, ssh, statsd, tableau, telegram, trino, vertica, virtualenv, webhdfs, winrm, yandex, zendesk # END EXTRAS HERE diff --git a/TESTING.rst b/TESTING.rst index e73ae464ac68f..07a3e73843271 100644 --- a/TESTING.rst +++ b/TESTING.rst @@ -281,12 +281,12 @@ The following integrations are available: - Integration required for OpenLDAP hooks * - pinot - Integration required for Apache Pinot hooks - * - presto - - Integration required for Presto hooks * - rabbitmq - Integration required for Celery executor tests * - redis - Integration required for Celery executor tests + * - trino + - Integration required for Trino hooks To start the ``mongo`` integration only, enter: diff --git a/airflow/providers/dependencies.json b/airflow/providers/dependencies.json index e4f100df684aa..eabb657d181d7 100644 --- a/airflow/providers/dependencies.json +++ b/airflow/providers/dependencies.json @@ -50,7 +50,8 @@ "presto", "salesforce", "sftp", - "ssh" + "ssh", + "trino" ], "hashicorp": [ "google" @@ -65,6 +66,7 @@ "mysql": [ "amazon", "presto", + "trino", "vertica" ], "opsgenie": [ diff --git a/airflow/providers/google/cloud/example_dags/example_trino_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_trino_to_gcs.py new file mode 100644 index 0000000000000..32dc8a004b79f --- /dev/null +++ b/airflow/providers/google/cloud/example_dags/example_trino_to_gcs.py @@ -0,0 +1,150 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example DAG using TrinoToGCSOperator. 
+""" +import os +import re + +from airflow import models +from airflow.providers.google.cloud.operators.bigquery import ( + BigQueryCreateEmptyDatasetOperator, + BigQueryCreateExternalTableOperator, + BigQueryDeleteDatasetOperator, + BigQueryExecuteQueryOperator, +) +from airflow.providers.google.cloud.transfers.trino_to_gcs import TrinoToGCSOperator +from airflow.utils.dates import days_ago + +GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", 'example-project') +GCS_BUCKET = os.environ.get("GCP_TRINO_TO_GCS_BUCKET_NAME", "test-trino-to-gcs-bucket") +DATASET_NAME = os.environ.get("GCP_TRINO_TO_GCS_DATASET_NAME", "test_trino_to_gcs_dataset") + +SOURCE_MULTIPLE_TYPES = "memory.default.test_multiple_types" +SOURCE_CUSTOMER_TABLE = "tpch.sf1.customer" + + +def safe_name(s: str) -> str: + """ + Remove invalid characters for filename + """ + return re.sub("[^0-9a-zA-Z_]+", "_", s) + + +with models.DAG( + dag_id="example_trino_to_gcs", + schedule_interval=None, # Override to match your needs + start_date=days_ago(1), + tags=["example"], +) as dag: + + create_dataset = BigQueryCreateEmptyDatasetOperator(task_id="create-dataset", dataset_id=DATASET_NAME) + + delete_dataset = BigQueryDeleteDatasetOperator( + task_id="delete_dataset", dataset_id=DATASET_NAME, delete_contents=True + ) + + # [START howto_operator_trino_to_gcs_basic] + trino_to_gcs_basic = TrinoToGCSOperator( + task_id="trino_to_gcs_basic", + sql=f"select * from {SOURCE_MULTIPLE_TYPES}", + bucket=GCS_BUCKET, + filename=f"{safe_name(SOURCE_MULTIPLE_TYPES)}.{{}}.json", + ) + # [END howto_operator_trino_to_gcs_basic] + + # [START howto_operator_trino_to_gcs_multiple_types] + trino_to_gcs_multiple_types = TrinoToGCSOperator( + task_id="trino_to_gcs_multiple_types", + sql=f"select * from {SOURCE_MULTIPLE_TYPES}", + bucket=GCS_BUCKET, + filename=f"{safe_name(SOURCE_MULTIPLE_TYPES)}.{{}}.json", + schema_filename=f"{safe_name(SOURCE_MULTIPLE_TYPES)}-schema.json", + gzip=False, + ) + # [END howto_operator_trino_to_gcs_multiple_types] + + # [START howto_operator_create_external_table_multiple_types] + create_external_table_multiple_types = BigQueryCreateExternalTableOperator( + task_id="create_external_table_multiple_types", + bucket=GCS_BUCKET, + source_objects=[f"{safe_name(SOURCE_MULTIPLE_TYPES)}.*.json"], + source_format="NEWLINE_DELIMITED_JSON", + destination_project_dataset_table=f"{DATASET_NAME}.{safe_name(SOURCE_MULTIPLE_TYPES)}", + schema_object=f"{safe_name(SOURCE_MULTIPLE_TYPES)}-schema.json", + ) + # [END howto_operator_create_external_table_multiple_types] + + read_data_from_gcs_multiple_types = BigQueryExecuteQueryOperator( + task_id="read_data_from_gcs_multiple_types", + sql=f"SELECT COUNT(*) FROM `{GCP_PROJECT_ID}.{DATASET_NAME}.{safe_name(SOURCE_MULTIPLE_TYPES)}`", + use_legacy_sql=False, + ) + + # [START howto_operator_trino_to_gcs_many_chunks] + trino_to_gcs_many_chunks = TrinoToGCSOperator( + task_id="trino_to_gcs_many_chunks", + sql=f"select * from {SOURCE_CUSTOMER_TABLE}", + bucket=GCS_BUCKET, + filename=f"{safe_name(SOURCE_CUSTOMER_TABLE)}.{{}}.json", + schema_filename=f"{safe_name(SOURCE_CUSTOMER_TABLE)}-schema.json", + approx_max_file_size_bytes=10_000_000, + gzip=False, + ) + # [END howto_operator_trino_to_gcs_many_chunks] + + create_external_table_many_chunks = BigQueryCreateExternalTableOperator( + task_id="create_external_table_many_chunks", + bucket=GCS_BUCKET, + source_objects=[f"{safe_name(SOURCE_CUSTOMER_TABLE)}.*.json"], + source_format="NEWLINE_DELIMITED_JSON", + 
destination_project_dataset_table=f"{DATASET_NAME}.{safe_name(SOURCE_CUSTOMER_TABLE)}", + schema_object=f"{safe_name(SOURCE_CUSTOMER_TABLE)}-schema.json", + ) + + # [START howto_operator_read_data_from_gcs_many_chunks] + read_data_from_gcs_many_chunks = BigQueryExecuteQueryOperator( + task_id="read_data_from_gcs_many_chunks", + sql=f"SELECT COUNT(*) FROM `{GCP_PROJECT_ID}.{DATASET_NAME}.{safe_name(SOURCE_CUSTOMER_TABLE)}`", + use_legacy_sql=False, + ) + # [END howto_operator_read_data_from_gcs_many_chunks] + + # [START howto_operator_trino_to_gcs_csv] + trino_to_gcs_csv = TrinoToGCSOperator( + task_id="trino_to_gcs_csv", + sql=f"select * from {SOURCE_MULTIPLE_TYPES}", + bucket=GCS_BUCKET, + filename=f"{safe_name(SOURCE_MULTIPLE_TYPES)}.{{}}.csv", + schema_filename=f"{safe_name(SOURCE_MULTIPLE_TYPES)}-schema.json", + export_format="csv", + ) + # [END howto_operator_trino_to_gcs_csv] + + create_dataset >> trino_to_gcs_basic + create_dataset >> trino_to_gcs_multiple_types + create_dataset >> trino_to_gcs_many_chunks + create_dataset >> trino_to_gcs_csv + + trino_to_gcs_multiple_types >> create_external_table_multiple_types >> read_data_from_gcs_multiple_types + trino_to_gcs_many_chunks >> create_external_table_many_chunks >> read_data_from_gcs_many_chunks + + trino_to_gcs_basic >> delete_dataset + trino_to_gcs_csv >> delete_dataset + read_data_from_gcs_multiple_types >> delete_dataset + read_data_from_gcs_many_chunks >> delete_dataset diff --git a/airflow/providers/google/cloud/transfers/trino_to_gcs.py b/airflow/providers/google/cloud/transfers/trino_to_gcs.py new file mode 100644 index 0000000000000..e2f2306bc80cc --- /dev/null +++ b/airflow/providers/google/cloud/transfers/trino_to_gcs.py @@ -0,0 +1,210 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements.  See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership.  The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.  You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied.  See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Any, Dict, List, Tuple + +from trino.client import TrinoResult +from trino.dbapi import Cursor as TrinoCursor + +from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator +from airflow.providers.trino.hooks.trino import TrinoHook +from airflow.utils.decorators import apply_defaults + + +class _TrinoToGCSTrinoCursorAdapter: +    """ +    An adapter that adds additional features to the Trino cursor. + +    The implementation of the cursor in the trino library is not sufficient. +    The following changes have been made: + +    * A poke mechanism for rows: you can look at the next row without consuming it. +    * The description attribute is available before reading the first row, thanks to the poke mechanism. +    * The iterator interface has been implemented. + +    A detailed description of the class methods is available in +    `PEP-249 `__.
+ """ + + def __init__(self, cursor: TrinoCursor): + self.cursor: TrinoCursor = cursor + self.rows: List[Any] = [] + self.initialized: bool = False + + @property + def description(self) -> List[Tuple]: + """ + This read-only attribute is a sequence of 7-item sequences. + + Each of these sequences contains information describing one result column: + + * ``name`` + * ``type_code`` + * ``display_size`` + * ``internal_size`` + * ``precision`` + * ``scale`` + * ``null_ok`` + + The first two items (``name`` and ``type_code``) are mandatory, the other + five are optional and are set to None if no meaningful values can be provided. + """ + if not self.initialized: + # Peek for first row to load description. + self.peekone() + return self.cursor.description + + @property + def rowcount(self) -> int: + """The read-only attribute specifies the number of rows""" + return self.cursor.rowcount + + def close(self) -> None: + """Close the cursor now""" + self.cursor.close() + + def execute(self, *args, **kwargs) -> TrinoResult: + """Prepare and execute a database operation (query or command).""" + self.initialized = False + self.rows = [] + return self.cursor.execute(*args, **kwargs) + + def executemany(self, *args, **kwargs): + """ + Prepare a database operation (query or command) and then execute it against all parameter + sequences or mappings found in the sequence seq_of_parameters. + """ + self.initialized = False + self.rows = [] + return self.cursor.executemany(*args, **kwargs) + + def peekone(self) -> Any: + """Return the next row without consuming it.""" + self.initialized = True + element = self.cursor.fetchone() + self.rows.insert(0, element) + return element + + def fetchone(self) -> Any: + """ + Fetch the next row of a query result set, returning a single sequence, or + ``None`` when no more data is available. + """ + if self.rows: + return self.rows.pop(0) + return self.cursor.fetchone() + + def fetchmany(self, size=None) -> list: + """ + Fetch the next set of rows of a query result, returning a sequence of sequences + (e.g. a list of tuples). An empty sequence is returned when no more rows are available. + """ + if size is None: + size = self.cursor.arraysize + + result = [] + for _ in range(size): + row = self.fetchone() + if row is None: + break + result.append(row) + + return result + + def __next__(self) -> Any: + """ + Return the next row from the currently executing SQL statement using the same semantics as + ``.fetchone()``. A ``StopIteration`` exception is raised when the result set is exhausted. + :return: + """ + result = self.fetchone() + if result is None: + raise StopIteration() + return result + + def __iter__(self) -> "_TrinoToGCSTrinoCursorAdapter": + """Return self to make cursors compatible to the iteration protocol""" + return self + + +class TrinoToGCSOperator(BaseSQLToGCSOperator): + """Copy data from TrinoDB to Google Cloud Storage in JSON or CSV format. + + :param trino_conn_id: Reference to a specific Trino hook. + :type trino_conn_id: str + """ + + ui_color = "#a0e08c" + + type_map = { + "BOOLEAN": "BOOL", + "TINYINT": "INT64", + "SMALLINT": "INT64", + "INTEGER": "INT64", + "BIGINT": "INT64", + "REAL": "FLOAT64", + "DOUBLE": "FLOAT64", + "DECIMAL": "NUMERIC", + "VARCHAR": "STRING", + "CHAR": "STRING", + "VARBINARY": "BYTES", + "JSON": "STRING", + "DATE": "DATE", + "TIME": "TIME", + # BigQuery don't time with timezone native. + "TIME WITH TIME ZONE": "STRING", + "TIMESTAMP": "TIMESTAMP", + # BigQuery supports a narrow range of time zones during import. 
+        # Use BigQuery's TIMESTAMP function if you need a TIMESTAMP type. +        "TIMESTAMP WITH TIME ZONE": "STRING", +        "IPADDRESS": "STRING", +        "UUID": "STRING", +    } + +    @apply_defaults +    def __init__(self, *, trino_conn_id: str = "trino_default", **kwargs): +        super().__init__(**kwargs) +        self.trino_conn_id = trino_conn_id + +    def query(self): +        """Queries Trino and returns a cursor to the results.""" +        trino = TrinoHook(trino_conn_id=self.trino_conn_id) +        conn = trino.get_conn() +        cursor = conn.cursor() +        self.log.info("Executing: %s", self.sql) +        cursor.execute(self.sql) +        return _TrinoToGCSTrinoCursorAdapter(cursor) + +    def field_to_bigquery(self, field) -> Dict[str, str]: +        """Convert a Trino field type to a BigQuery field type.""" +        clear_field_type = field[1].upper() +        # remove type argument e.g. DECIMAL(2, 10) => DECIMAL +        clear_field_type, _, _ = clear_field_type.partition("(") +        new_field_type = self.type_map.get(clear_field_type, "STRING") + +        return {"name": field[0], "type": new_field_type} + +    def convert_type(self, value, schema_type): +        """ +        Do nothing. Trino uses JSON on the transport layer, so types are simple. + +        :param value: Trino column value +        :type value: Any +        :param schema_type: BigQuery data type +        :type schema_type: str +        """ +        return value diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml index 49170834fa698..6210f87869747 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -629,6 +629,10 @@ transfers:     target-integration-name: Google Cloud Storage (GCS)     how-to-guide: /docs/apache-airflow-providers-google/operators/transfer/presto_to_gcs.rst     python-module: airflow.providers.google.cloud.transfers.presto_to_gcs +  - source-integration-name: Trino +    target-integration-name: Google Cloud Storage (GCS) +    how-to-guide: /docs/apache-airflow-providers-google/operators/transfer/trino_to_gcs.rst +    python-module: airflow.providers.google.cloud.transfers.trino_to_gcs   - source-integration-name: SQL     target-integration-name: Google Cloud Storage (GCS)     python-module: airflow.providers.google.cloud.transfers.sql_to_gcs diff --git a/airflow/providers/mysql/provider.yaml b/airflow/providers/mysql/provider.yaml index a9b408fcf0174..3b40f50f67bd4 100644 --- a/airflow/providers/mysql/provider.yaml +++ b/airflow/providers/mysql/provider.yaml @@ -52,9 +52,12 @@ transfers:   - source-integration-name: Amazon Simple Storage Service (S3)     target-integration-name: MySQL     python-module: airflow.providers.mysql.transfers.s3_to_mysql -  - source-integration-name: Snowflake +  - source-integration-name: Presto     target-integration-name: MySQL     python-module: airflow.providers.mysql.transfers.presto_to_mysql +  - source-integration-name: Trino +    target-integration-name: MySQL +    python-module: airflow.providers.mysql.transfers.trino_to_mysql  hook-class-names:   - airflow.providers.mysql.hooks.mysql.MySqlHook diff --git a/airflow/providers/mysql/transfers/trino_to_mysql.py b/airflow/providers/mysql/transfers/trino_to_mysql.py new file mode 100644 index 0000000000000..b97550e116d8f --- /dev/null +++ b/airflow/providers/mysql/transfers/trino_to_mysql.py @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements.  See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership.
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.  You may obtain a copy of the License at +# +#   http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied.  See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Dict, Optional + +from airflow.models import BaseOperator +from airflow.providers.mysql.hooks.mysql import MySqlHook +from airflow.providers.trino.hooks.trino import TrinoHook +from airflow.utils.decorators import apply_defaults + + +class TrinoToMySqlOperator(BaseOperator): +    """ +    Moves data from Trino to MySQL. Note that for now the data is loaded +    into memory before being pushed to MySQL, so this operator should +    be used for small amounts of data only. + +    :param sql: SQL query to execute against Trino. (templated) +    :type sql: str +    :param mysql_table: target MySQL table, use dot notation to target a +        specific database. (templated) +    :type mysql_table: str +    :param mysql_conn_id: destination MySQL connection +    :type mysql_conn_id: str +    :param trino_conn_id: source Trino connection +    :type trino_conn_id: str +    :param mysql_preoperator: SQL statement to run against MySQL prior to the +        import, typically used to truncate or delete the existing data in place +        of the data coming in, allowing the task to be idempotent (running +        the task twice won't double-load data). (templated) +    :type mysql_preoperator: str +    """ + +    template_fields = ('sql', 'mysql_table', 'mysql_preoperator') +    template_ext = ('.sql',) +    template_fields_renderers = {"mysql_preoperator": "sql"} +    ui_color = '#a0e08c' + +    @apply_defaults +    def __init__( +        self, +        *, +        sql: str, +        mysql_table: str, +        trino_conn_id: str = 'trino_default', +        mysql_conn_id: str = 'mysql_default', +        mysql_preoperator: Optional[str] = None, +        **kwargs, +    ) -> None: +        super().__init__(**kwargs) +        self.sql = sql +        self.mysql_table = mysql_table +        self.mysql_conn_id = mysql_conn_id +        self.mysql_preoperator = mysql_preoperator +        self.trino_conn_id = trino_conn_id + +    def execute(self, context: Dict) -> None: +        trino = TrinoHook(trino_conn_id=self.trino_conn_id) +        self.log.info("Extracting data from Trino: %s", self.sql) +        results = trino.get_records(self.sql) + +        mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id) +        if self.mysql_preoperator: +            self.log.info("Running MySQL preoperator") +            self.log.info(self.mysql_preoperator) +            mysql.run(self.mysql_preoperator) + +        self.log.info("Inserting rows into MySQL") +        mysql.insert_rows(table=self.mysql_table, rows=results) diff --git a/airflow/providers/trino/CHANGELOG.rst b/airflow/providers/trino/CHANGELOG.rst new file mode 100644 index 0000000000000..cef7dda80708a --- /dev/null +++ b/airflow/providers/trino/CHANGELOG.rst @@ -0,0 +1,25 @@ + .. Licensed to the Apache Software Foundation (ASF) under one +    or more contributor license agreements.  See the NOTICE file +    distributed with this work for additional information +    regarding copyright ownership.  The ASF licenses this file +    to you under the Apache License, Version 2.0 (the +    "License"); you may not use this file except in compliance +    with the License.  You may obtain a copy of the License at + + ..   http://www.apache.org/licenses/LICENSE-2.0 + + ..
Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +Changelog +--------- + +1.0.0 +..... + +Initial version of the provider. diff --git a/airflow/providers/trino/__init__.py b/airflow/providers/trino/__init__.py new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/airflow/providers/trino/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/trino/hooks/__init__.py b/airflow/providers/trino/hooks/__init__.py new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/airflow/providers/trino/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py new file mode 100644 index 0000000000000..0914d04b32e4b --- /dev/null +++ b/airflow/providers/trino/hooks/trino.py @@ -0,0 +1,191 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import os +from typing import Any, Iterable, Optional + +import trino +from trino.exceptions import DatabaseError +from trino.transaction import IsolationLevel + +from airflow import AirflowException +from airflow.configuration import conf +from airflow.hooks.dbapi import DbApiHook +from airflow.models import Connection + + +class TrinoException(Exception): +    """Trino exception""" + + +def _boolify(value): +    if isinstance(value, bool): +        return value +    if isinstance(value, str): +        if value.lower() == 'false': +            return False +        elif value.lower() == 'true': +            return True +    return value + + +class TrinoHook(DbApiHook): +    """ +    Interact with Trino through the trino package. + +    >>> ph = TrinoHook() +    >>> sql = "SELECT count(1) AS num FROM airflow.static_babynames" +    >>> ph.get_records(sql) +    [[340698]] +    """ + +    conn_name_attr = 'trino_conn_id' +    default_conn_name = 'trino_default' +    conn_type = 'trino' +    hook_name = 'Trino' + +    def get_conn(self) -> Connection: +        """Returns a connection object""" +        db = self.get_connection( +            self.trino_conn_id  # type: ignore[attr-defined]  # pylint: disable=no-member +        ) +        extra = db.extra_dejson +        auth = None +        if db.password and extra.get('auth') == 'kerberos': +            raise AirflowException("Kerberos authentication doesn't support password.") +        elif db.password: +            auth = trino.auth.BasicAuthentication(db.login, db.password) +        elif extra.get('auth') == 'kerberos': +            auth = trino.auth.KerberosAuthentication( +                config=extra.get('kerberos__config', os.environ.get('KRB5_CONFIG')), +                service_name=extra.get('kerberos__service_name'), +                mutual_authentication=_boolify(extra.get('kerberos__mutual_authentication', False)), +                force_preemptive=_boolify(extra.get('kerberos__force_preemptive', False)), +                hostname_override=extra.get('kerberos__hostname_override'), +                sanitize_mutual_error_response=_boolify( +                    extra.get('kerberos__sanitize_mutual_error_response', True) +                ), +                principal=extra.get('kerberos__principal', conf.get('kerberos', 'principal')), +                delegate=_boolify(extra.get('kerberos__delegate', False)), +                ca_bundle=extra.get('kerberos__ca_bundle'), +            ) + +        trino_conn = trino.dbapi.connect( +            host=db.host, +            port=db.port, +            user=db.login, +            source=db.extra_dejson.get('source', 'airflow'), +            http_scheme=db.extra_dejson.get('protocol', 'http'), +            catalog=db.extra_dejson.get('catalog', 'hive'), +            schema=db.schema, +            auth=auth, +            isolation_level=self.get_isolation_level(),  # type: ignore[func-returns-value] +        ) +        if extra.get('verify') is not None: +            # Unfortunately the verify parameter is not available via the public API. +            # The PR adding it is merged in the trino library, but has not been released.
+            # See: https://github.com/trinodb/trino-python-client/pull/31 +            trino_conn._http_session.verify = _boolify(extra['verify'])  # pylint: disable=protected-access + +        return trino_conn + +    def get_isolation_level(self) -> Any: +        """Returns an isolation level""" +        db = self.get_connection( +            self.trino_conn_id  # type: ignore[attr-defined]  # pylint: disable=no-member +        ) +        isolation_level = db.extra_dejson.get('isolation_level', 'AUTOCOMMIT').upper() +        return getattr(IsolationLevel, isolation_level, IsolationLevel.AUTOCOMMIT) + +    @staticmethod +    def _strip_sql(sql: str) -> str: +        return sql.strip().rstrip(';') + +    def get_records(self, hql, parameters: Optional[dict] = None): +        """Get a set of records from Trino""" +        try: +            return super().get_records(self._strip_sql(hql), parameters) +        except DatabaseError as e: +            raise TrinoException(e) + +    def get_first(self, hql: str, parameters: Optional[dict] = None) -> Any: +        """Returns only the first row, regardless of how many rows the query returns.""" +        try: +            return super().get_first(self._strip_sql(hql), parameters) +        except DatabaseError as e: +            raise TrinoException(e) + +    def get_pandas_df(self, hql, parameters=None, **kwargs): +        """Get a pandas DataFrame from a SQL query.""" +        import pandas + +        cursor = self.get_cursor() +        try: +            cursor.execute(self._strip_sql(hql), parameters) +            data = cursor.fetchall() +        except DatabaseError as e: +            raise TrinoException(e) +        column_descriptions = cursor.description +        if data: +            df = pandas.DataFrame(data, **kwargs) +            df.columns = [c[0] for c in column_descriptions] +        else: +            df = pandas.DataFrame(**kwargs) +        return df + +    def run( +        self, +        hql, +        autocommit: bool = False, +        parameters: Optional[dict] = None, +    ) -> None: +        """Execute the statement against Trino. Can be used to create views.""" +        return super().run(sql=self._strip_sql(hql), parameters=parameters) + +    def insert_rows( +        self, +        table: str, +        rows: Iterable[tuple], +        target_fields: Optional[Iterable[str]] = None, +        commit_every: int = 0, +        replace: bool = False, +        **kwargs, +    ) -> None: +        """ +        A generic way to insert a set of tuples into a table. + +        :param table: Name of the target table +        :type table: str +        :param rows: The rows to insert into the table +        :type rows: iterable of tuples +        :param target_fields: The names of the columns to fill in the table +        :type target_fields: iterable of strings +        :param commit_every: The maximum number of rows to insert in one +            transaction. Set to 0 to insert all rows in one transaction. +        :type commit_every: int +        :param replace: Whether to replace instead of insert +        :type replace: bool +        """ +        if self.get_isolation_level() == IsolationLevel.AUTOCOMMIT: +            self.log.info( +                'Transactions are not enabled in the trino connection. ' +                'Please use the isolation_level property to enable it. ' +                'Falling back to insert all rows in one transaction.' +            ) +            commit_every = 0 + +        super().insert_rows(table, rows, target_fields, commit_every) diff --git a/airflow/providers/trino/provider.yaml b/airflow/providers/trino/provider.yaml new file mode 100644 index 0000000000000..a59aaae6abbe1 --- /dev/null +++ b/airflow/providers/trino/provider.yaml @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements.  See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership.
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +package-name: apache-airflow-providers-trino +name: Trino +description: | + `Trino `__ + +versions: + - 1.0.0 + +integrations: + - integration-name: Trino + external-doc-url: https://trino.io/docs/ + logo: /integration-logos/trino/trino-og.png + tags: [software] + +hooks: + - integration-name: Trino + python-modules: + - airflow.providers.trino.hooks.trino + +hook-class-names: + - airflow.providers.trino.hooks.trino.TrinoHook diff --git a/airflow/sensors/sql.py b/airflow/sensors/sql.py index 923af6c75d062..efe88fff06648 100644 --- a/airflow/sensors/sql.py +++ b/airflow/sensors/sql.py @@ -84,6 +84,7 @@ def _get_hook(self): 'presto', 'snowflake', 'sqlite', + 'trino', 'vertica', } if conn.conn_type not in allowed_conn_type: diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 46ace4aa7b523..0a30901fce230 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -516,6 +516,16 @@ def create_default_connections(session=None): ), session, ) + merge_conn( + Connection( + conn_id="trino_default", + conn_type="trino", + host="localhost", + schema="hive", + port=3400, + ), + session, + ) merge_conn( Connection( conn_id="vertica_default", diff --git a/breeze b/breeze index c85a5acf9ca2e..302cebbdbb285 100755 --- a/breeze +++ b/breeze @@ -819,7 +819,7 @@ function breeze::parse_arguments() { else INTEGRATIONS+=("${INTEGRATION}") fi - if [[ " ${INTEGRATIONS[*]} " =~ " presto " ]]; then + if [[ " ${INTEGRATIONS[*]} " =~ " trino " ]]; then INTEGRATIONS+=("kerberos"); fi echo diff --git a/breeze-complete b/breeze-complete index daa03bce79b81..2004a1a9ae185 100644 --- a/breeze-complete +++ b/breeze-complete @@ -25,7 +25,7 @@ _breeze_allowed_python_major_minor_versions="2.7 3.5 3.6 3.7 3.8" _breeze_allowed_backends="sqlite mysql postgres" -_breeze_allowed_integrations="cassandra kerberos mongo openldap pinot presto rabbitmq redis statsd all" +_breeze_allowed_integrations="cassandra kerberos mongo openldap pinot rabbitmq redis statsd trino all" _breeze_allowed_generate_constraints_modes="source-providers pypi-providers no-providers" # registrys is good here even if it is not correct english. We are adding s automatically to all variables _breeze_allowed_github_registrys="docker.pkg.github.com ghcr.io" diff --git a/docs/apache-airflow-providers-google/operators/transfer/trino_to_gcs.rst b/docs/apache-airflow-providers-google/operators/transfer/trino_to_gcs.rst new file mode 100644 index 0000000000000..29dc5405e7484 --- /dev/null +++ b/docs/apache-airflow-providers-google/operators/transfer/trino_to_gcs.rst @@ -0,0 +1,142 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. 
You may obtain a copy of the License at + + ..   http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, +    software distributed under the License is distributed on an +    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +    KIND, either express or implied.  See the License for the +    specific language governing permissions and limitations +    under the License. + + +Trino to Google Cloud Storage Transfer Operator +=============================================== + +`Trino `__ is an open source, fast, distributed SQL query engine for running interactive +analytic queries against data sources of all sizes ranging from gigabytes to petabytes. Trino allows +querying data where it lives, including Hive, Cassandra, relational databases or even proprietary data stores. +A single Trino query can combine data from multiple sources, allowing for analytics across your entire +organization. + +`Google Cloud Storage `__ allows world-wide storage and retrieval of +any amount of data at any time. You can use it to store backup and +`archive data `__ as well +as a `data source for BigQuery `__. + + +Data transfer +------------- + +Transferring files between Trino and Google Cloud Storage is performed with the +:class:`~airflow.providers.google.cloud.transfers.trino_to_gcs.TrinoToGCSOperator` operator. + +This operator has 3 required parameters: + +* ``sql`` - The SQL to execute. +* ``bucket`` - The bucket to upload to. +* ``filename`` - The filename to use as the object name when uploading to Google Cloud Storage. +  A ``{}`` should be specified in the filename to allow the operator to inject file +  numbers in cases where the file is split due to size. + +All parameters are described in the reference documentation - :class:`~airflow.providers.google.cloud.transfers.trino_to_gcs.TrinoToGCSOperator`. + +An example operator call might look like this: + +.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_trino_to_gcs.py +    :language: python +    :dedent: 4 +    :start-after: [START howto_operator_trino_to_gcs_basic] +    :end-before: [END howto_operator_trino_to_gcs_basic] + +Choice of data format +^^^^^^^^^^^^^^^^^^^^^ + +The operator supports two output formats: + +* ``json`` - JSON Lines (default) +* ``csv`` + +You can specify these options with the ``export_format`` parameter. + +If you want a CSV file to be created, your operator call might look like this: + +.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_trino_to_gcs.py +    :language: python +    :dedent: 4 +    :start-after: [START howto_operator_trino_to_gcs_csv] +    :end-before: [END howto_operator_trino_to_gcs_csv] + +Generating BigQuery schema +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +If you set the ``schema_filename`` parameter, a ``.json`` file containing the BigQuery schema fields for the table +will be dumped from the database and uploaded to the bucket. + +If you want to create a schema file, then an example operator call might look like this: + +.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_trino_to_gcs.py +    :language: python +    :dedent: 4 +    :start-after: [START howto_operator_trino_to_gcs_multiple_types] +    :end-before: [END howto_operator_trino_to_gcs_multiple_types] + +For more information about the BigQuery schema, please look at +`Specifying schema `__ in the BigQuery documentation.
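The schema fields are derived from the Trino cursor description through the operator's built-in type map. Below is a minimal sketch of that mapping; the task id, bucket, and column names are hypothetical, and the operator is instantiated only to make the conversion helper visible:

.. code-block:: python

    from airflow.providers.google.cloud.transfers.trino_to_gcs import TrinoToGCSOperator

    op = TrinoToGCSOperator(
        task_id="schema_mapping_sketch",  # hypothetical task, not attached to any DAG
        sql="SELECT 1",
        bucket="example-bucket",
        filename="example.{}.json",
    )

    # Each field is the start of a PEP-249 cursor description tuple: (name, type_code, ...).
    op.field_to_bigquery(("customer_id", "bigint"))  # {'name': 'customer_id', 'type': 'INT64'}
    op.field_to_bigquery(("price", "decimal(10,2)"))  # {'name': 'price', 'type': 'NUMERIC'}
    op.field_to_bigquery(("created_at", "timestamp with time zone"))  # {'name': 'created_at', 'type': 'STRING'}

Trino types without an explicit mapping fall back to ``STRING``.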
+ +Division of the result into multiple files +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +This operator supports the ability to split a large result into multiple files. The ``approx_max_file_size_bytes`` +parameter allows developers to specify the file size of the splits. By default, each file has no more +than 1 900 000 000 bytes (1900 MB). + +Check `Quotas & limits in Google Cloud Storage `__ to see the +maximum allowed file size for a single object. + +If you want to create 10 MB files, your code might look like this: + +.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_trino_to_gcs.py +    :language: python +    :dedent: 4 +    :start-after: [START howto_operator_trino_to_gcs_many_chunks] +    :end-before: [END howto_operator_trino_to_gcs_many_chunks] + +Querying data using BigQuery +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The data available in Google Cloud Storage can be used by BigQuery. You can load data into BigQuery or +reference GCS data directly in your queries. For information about loading data into BigQuery, please look at +`Introduction to loading data from Cloud Storage `__ +in the BigQuery documentation. For information about querying GCS data, please look at +`Querying Cloud Storage data `__ in +the BigQuery documentation. + +Airflow also has numerous operators that allow you to work with BigQuery. +For example, if you want to create an external table that allows you to run queries that +read data directly from GCS, then you can use :class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryCreateExternalTableOperator`. +Using this operator looks like this: + +.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_trino_to_gcs.py +    :language: python +    :dedent: 4 +    :start-after: [START howto_operator_create_external_table_multiple_types] +    :end-before: [END howto_operator_create_external_table_multiple_types] + +For more information about the Airflow and BigQuery integration, please look at +the Python API Reference - :class:`~airflow.providers.google.cloud.operators.bigquery`. + +Reference +^^^^^^^^^ + +For further information, look at: + +* `Trino Documentation `__ + +* `Google Cloud Storage Documentation `__ diff --git a/docs/apache-airflow-providers-trino/commits.rst b/docs/apache-airflow-providers-trino/commits.rst new file mode 100644 index 0000000000000..5f0341d81e984 --- /dev/null +++ b/docs/apache-airflow-providers-trino/commits.rst @@ -0,0 +1,26 @@ + + .. Licensed to the Apache Software Foundation (ASF) under one +    or more contributor license agreements.  See the NOTICE file +    distributed with this work for additional information +    regarding copyright ownership.  The ASF licenses this file +    to you under the Apache License, Version 2.0 (the +    "License"); you may not use this file except in compliance +    with the License.  You may obtain a copy of the License at + + ..   http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, +    software distributed under the License is distributed on an +    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +    KIND, either express or implied.  See the License for the +    specific language governing permissions and limitations +    under the License. + +Package apache-airflow-providers-trino ------------------------------------------------------ + +`Trino `__ + + +This is the detailed commit list of changes for the ``trino`` provider package.
+For high-level changelog, see :doc:`package information including changelog `. diff --git a/docs/apache-airflow-providers-trino/index.rst b/docs/apache-airflow-providers-trino/index.rst new file mode 100644 index 0000000000000..e74c7d62e7d23 --- /dev/null +++ b/docs/apache-airflow-providers-trino/index.rst @@ -0,0 +1,43 @@ + + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +``apache-airflow-providers-trino`` +=================================== + +Content +------- + +.. toctree:: + :maxdepth: 1 + :caption: References + + Python API <_api/airflow/providers/trino/index> + +.. toctree:: + :maxdepth: 1 + :caption: Resources + + PyPI Repository + +.. THE REMINDER OF THE FILE IS AUTOMATICALLY GENERATED. IT WILL BE OVERWRITTEN AT RELEASE TIME! + +.. toctree:: + :maxdepth: 1 + :caption: Commits + + Detailed list of commits diff --git a/docs/apache-airflow/extra-packages-ref.rst b/docs/apache-airflow/extra-packages-ref.rst index 174edec59347d..b902868f5c164 100644 --- a/docs/apache-airflow/extra-packages-ref.rst +++ b/docs/apache-airflow/extra-packages-ref.rst @@ -54,7 +54,7 @@ python dependencies for the provided package. 
+---------------------+-----------------------------------------------------+----------------------------------------------------------------------------+ | google_auth | ``pip install 'apache-airflow[google_auth]'`` | Google auth backend | +---------------------+-----------------------------------------------------+----------------------------------------------------------------------------+ -| kerberos | ``pip install 'apache-airflow[kerberos]'`` | Kerberos integration for Kerberized services (Hadoop, Presto) | +| kerberos | ``pip install 'apache-airflow[kerberos]'`` | Kerberos integration for Kerberized services (Hadoop, Presto, Trino) | +---------------------+-----------------------------------------------------+----------------------------------------------------------------------------+ | ldap | ``pip install 'apache-airflow[ldap]'`` | LDAP authentication for users | +---------------------+-----------------------------------------------------+----------------------------------------------------------------------------+ @@ -235,6 +235,8 @@ Those are extras that add dependencies needed for integration with other softwar +---------------------+-----------------------------------------------------+-------------------------------------------+ | singularity | ``pip install 'apache-airflow[singularity]'`` | Singularity container operator | +---------------------+-----------------------------------------------------+-------------------------------------------+ +| trino | ``pip install 'apache-airflow[trino]'`` | All Trino related operators & hooks | ++---------------------+-----------------------------------------------------+-------------------------------------------+ Other extras diff --git a/docs/exts/docs_build/errors.py b/docs/exts/docs_build/errors.py index 1a2ae0698073a..3fe9f36d810b3 100644 --- a/docs/exts/docs_build/errors.py +++ b/docs/exts/docs_build/errors.py @@ -69,7 +69,7 @@ def display_errors_summary(build_errors: Dict[str, List[DocBuildError]]) -> None console.print("-" * 30, f"[red]Error {warning_no:3}[/]", "-" * 20) console.print(error.message) console.print() - if error.file_path and error.file_path != "" and error.line_no: + if error.file_path and not error.file_path.endswith("") and error.line_no: console.print( f"File path: {os.path.relpath(error.file_path, start=DOCS_DIR)} ({error.line_no})" ) diff --git a/docs/exts/docs_build/spelling_checks.py b/docs/exts/docs_build/spelling_checks.py index f0b272267b535..24ce3f170584f 100644 --- a/docs/exts/docs_build/spelling_checks.py +++ b/docs/exts/docs_build/spelling_checks.py @@ -180,6 +180,6 @@ def _display_error(error: SpellingError): console.print(f"Suggested Spelling: '{error.suggestion}'") if error.context_line: console.print(f"Line with Error: '{error.context_line}'") - if error.line_no: + if error.file_path and not error.file_path.endswith("") and error.line_no: console.print(f"Line Number: {error.line_no}") console.print(prepare_code_snippet(error.file_path, error.line_no)) diff --git a/docs/integration-logos/trino/trino-og.png b/docs/integration-logos/trino/trino-og.png new file mode 100644 index 0000000000000..55bedf93dd346 Binary files /dev/null and b/docs/integration-logos/trino/trino-og.png differ diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index f203300c76a50..70c6fb32f0a1d 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1352,6 +1352,7 @@ tooltips traceback tracebacks travis +trino trojan tsv ttl diff --git a/scripts/ci/docker-compose/integration-kerberos.yml 
b/scripts/ci/docker-compose/integration-kerberos.yml index d157bd684c82d..95fc8c9c79e2f 100644 --- a/scripts/ci/docker-compose/integration-kerberos.yml +++ b/scripts/ci/docker-compose/integration-kerberos.yml @@ -36,7 +36,7 @@ services: /opt/kerberos-utils/create_client.sh bob bob /root/kerberos-keytabs/airflow.keytab; /opt/kerberos-utils/create_service.sh krb5-machine-example-com airflow /root/kerberos-keytabs/airflow.keytab; - /opt/kerberos-utils/create_service.sh presto HTTP /root/kerberos-keytabs/presto.keytab; + /opt/kerberos-utils/create_service.sh trino HTTP /root/kerberos-keytabs/trino.keytab; healthcheck: test: |- python -c " diff --git a/scripts/ci/docker-compose/integration-redis.yml b/scripts/ci/docker-compose/integration-redis.yml index ab353d267ebe8..3cdf68caf18b9 100644 --- a/scripts/ci/docker-compose/integration-redis.yml +++ b/scripts/ci/docker-compose/integration-redis.yml @@ -21,7 +21,7 @@ services: image: redis:5.0.1 volumes: - /dev/urandom:/dev/random # Required to get non-blocking entropy source - - redis-db-volume:/data/presto + - redis-db-volume:/data/redis ports: - "${REDIS_HOST_PORT}:6379" healthcheck: diff --git a/scripts/ci/docker-compose/integration-presto.yml b/scripts/ci/docker-compose/integration-trino.yml similarity index 81% rename from scripts/ci/docker-compose/integration-presto.yml rename to scripts/ci/docker-compose/integration-trino.yml index 7fce2069b31c9..3f420fb61ca17 100644 --- a/scripts/ci/docker-compose/integration-presto.yml +++ b/scripts/ci/docker-compose/integration-trino.yml @@ -17,10 +17,10 @@ --- version: "2.2" services: - presto: - image: apache/airflow:presto-2020.10.08 - container_name: presto - hostname: presto + trino: + image: apache/airflow:trino-2021.04.04 + container_name: trino + hostname: trino domainname: example.com networks: @@ -40,19 +40,19 @@ services: volumes: - /dev/urandom:/dev/random # Required to get non-blocking entropy source - ../dockerfiles/krb5-kdc-server/krb5.conf:/etc/krb5.conf:ro - - presto-db-volume:/data/presto - - kerberos-keytabs:/home/presto/kerberos-keytabs + - trino-db-volume:/data/trino + - kerberos-keytabs:/home/trino/kerberos-keytabs environment: - KRB5_CONFIG=/etc/krb5.conf - KRB5_TRACE=/dev/stderr - - KRB5_KTNAME=/home/presto/kerberos-keytabs/presto.keytab + - KRB5_KTNAME=/home/trino/kerberos-keytabs/trino.keytab airflow: environment: - - INTEGRATION_PRESTO=true + - INTEGRATION_TRINO=true depends_on: - presto: + trino: condition: service_healthy volumes: - presto-db-volume: + trino-db-volume: diff --git a/scripts/ci/dockerfiles/krb5-kdc-server/utils/create_service.sh b/scripts/ci/dockerfiles/krb5-kdc-server/utils/create_service.sh index 30161a3f6c5c7..c92aeab70f629 100755 --- a/scripts/ci/dockerfiles/krb5-kdc-server/utils/create_service.sh +++ b/scripts/ci/dockerfiles/krb5-kdc-server/utils/create_service.sh @@ -29,7 +29,7 @@ Usage: ${CMDNAME} Creates an account for the service. The service name is combined with the domain to create an principal name. If your service is named -\"presto\" a principal \"presto.example.com\" will be created. +\"trino\" a principal \"trino.example.com\" will be created. The protocol can have any value, but it must be identical in the server and client configuration. For example: HTTP. 
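The kerberized Trino container configured above is what the Kerberos extras read by ``TrinoHook.get_conn()`` (see ``airflow/providers/trino/hooks/trino.py`` earlier in this patch) are meant to talk to. A minimal sketch of a matching Airflow connection follows; the connection id, host, and principal are hypothetical, and only extras that the hook actually reads are set:

.. code-block:: python

    import json

    from airflow.models import Connection

    trino_kerberized = Connection(
        conn_id="trino_kerberized",  # hypothetical connection id
        conn_type="trino",
        host="trino.example.com",
        port=7778,  # HTTPS port exposed by the integration container
        schema="default",
        extra=json.dumps(
            {
                "protocol": "https",  # read by get_conn() as http_scheme
                "auth": "kerberos",  # selects trino.auth.KerberosAuthentication
                "kerberos__service_name": "HTTP",  # hypothetical; must match the service principal
                "kerberos__principal": "airflow@EXAMPLE.COM",  # hypothetical principal
                "kerberos__config": "/etc/krb5.conf",
                "verify": "false",  # the CI container uses a self-signed certificate
            }
        ),
    )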
diff --git a/scripts/ci/dockerfiles/presto/Dockerfile b/scripts/ci/dockerfiles/trino/Dockerfile similarity index 78% rename from scripts/ci/dockerfiles/presto/Dockerfile rename to scripts/ci/dockerfiles/trino/Dockerfile index 80ccbfd344527..080491f6a7f4e 100644 --- a/scripts/ci/dockerfiles/presto/Dockerfile +++ b/scripts/ci/dockerfiles/trino/Dockerfile @@ -14,8 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -ARG PRESTO_VERSION="330" -FROM prestosql/presto:${PRESTO_VERSION} +ARG TRINO_VERSION="354" +FROM trinodb/trino:${TRINO_VERSION} # Obtain root privileges USER 0 @@ -23,16 +23,16 @@ USER 0 # Setup entrypoint COPY entrypoint.sh /entrypoint.sh ENTRYPOINT ["/entrypoint.sh"] -CMD ["/usr/lib/presto/bin/run-presto"] +CMD ["/usr/lib/trino/bin/run-trino"] # Expose HTTPS EXPOSE 7778 -LABEL org.apache.airflow.component="presto" -LABEL org.apache.airflow.presto.core.version="${PRESTO_VERSION}" -LABEL org.apache.airflow.airflow_bats.version="${AIRFLOW_PRESTO_VERSION}" +LABEL org.apache.airflow.component="trino" +LABEL org.apache.airflow.trino.core.version="${TRINO_VERSION}" +LABEL org.apache.airflow.airflow_trino.version="${AIRFLOW_TRINO_VERSION}" LABEL org.apache.airflow.commit_sha="${COMMIT_SHA}" LABEL maintainer="Apache Airflow Community " # Restore user -USER presto:presto +USER trino:trino diff --git a/scripts/ci/dockerfiles/presto/build_and_push.sh b/scripts/ci/dockerfiles/trino/build_and_push.sh similarity index 79% rename from scripts/ci/dockerfiles/presto/build_and_push.sh rename to scripts/ci/dockerfiles/trino/build_and_push.sh index d3cac47775b69..ea8a59d0742b0 100755 --- a/scripts/ci/dockerfiles/presto/build_and_push.sh +++ b/scripts/ci/dockerfiles/trino/build_and_push.sh @@ -21,24 +21,24 @@ DOCKERHUB_REPO=${DOCKERHUB_REPO:="airflow"} readonly DOCKERHUB_USER readonly DOCKERHUB_REPO -PRESTO_VERSION="330" -readonly PRESTO_VERSION +TRINO_VERSION="354" +readonly TRINO_VERSION -AIRFLOW_PRESTO_VERSION="2020.10.08" -readonly AIRFLOW_PRESTO_VERSION +AIRFLOW_TRINO_VERSION="2021.04.04" +readonly AIRFLOW_TRINO_VERSION COMMIT_SHA=$(git rev-parse HEAD) readonly COMMIT_SHA cd "$( dirname "${BASH_SOURCE[0]}" )" || exit 1 -TAG="${DOCKERHUB_USER}/${DOCKERHUB_REPO}:presto-${AIRFLOW_PRESTO_VERSION}" +TAG="${DOCKERHUB_USER}/${DOCKERHUB_REPO}:trino-${AIRFLOW_TRINO_VERSION}" readonly TAG docker build . \ --pull \ - --build-arg "PRESTO_VERSION=${PRESTO_VERSION}" \ - --build-arg "AIRFLOW_PRESTO_VERSION=${AIRFLOW_PRESTO_VERSION}" \ + --build-arg "TRINO_VERSION=${TRINO_VERSION}" \ + --build-arg "AIRFLOW_TRINO_VERSION=${AIRFLOW_TRINO_VERSION}" \ --build-arg "COMMIT_SHA=${COMMIT_SHA}" \ --tag "${TAG}" diff --git a/scripts/ci/dockerfiles/presto/entrypoint.sh b/scripts/ci/dockerfiles/trino/entrypoint.sh similarity index 73% rename from scripts/ci/dockerfiles/presto/entrypoint.sh rename to scripts/ci/dockerfiles/trino/entrypoint.sh index 9c8d1130beeb2..314cc5a8ee166 100755 --- a/scripts/ci/dockerfiles/presto/entrypoint.sh +++ b/scripts/ci/dockerfiles/trino/entrypoint.sh @@ -32,7 +32,7 @@ function check_service { RES=$? set -e if [[ ${RES} == 0 ]]; then - echo "${COLOR_GREEN}OK. ${COLOR_RESET}" + echo "OK." break else echo -n "." 
@@ -58,27 +58,29 @@ function log() { echo -e "\u001b[32m[$(date +'%Y-%m-%dT%H:%M:%S%z')]: $*\u001b[0m" } -if [ -f /tmp/presto-initiaalized ]; then +if [ -f /tmp/trino-initialized ]; then exec /bin/sh -c "$@" fi -PRESTO_CONFIG_FILE="/usr/lib/presto/default/etc/config.properties" -JVM_CONFIG_FILE="/usr/lib/presto/default/etc/jvm.config" +TRINO_CONFIG_FILE="/etc/trino/config.properties" +JVM_CONFIG_FILE="/etc/trino/jvm.config" log "Generate self-signed SSL certificate" JKS_KEYSTORE_FILE=/tmp/ssl_keystore.jks -JKS_KEYSTORE_PASS=presto +JKS_KEYSTORE_PASS=trinodb +keytool -delete --alias "trino-ssl" -keystore "${JKS_KEYSTORE_FILE}" -storepass "${JKS_KEYSTORE_PASS}" || true + keytool \ -genkeypair \ - -alias "presto-ssl" \ + -alias "trino-ssl" \ -keyalg RSA \ -keystore "${JKS_KEYSTORE_FILE}" \ -validity 10000 \ -dname "cn=Unknown, ou=Unknown, o=Unknown, c=Unknown"\ -storepass "${JKS_KEYSTORE_PASS}" -log "Set up SSL in ${PRESTO_CONFIG_FILE}" -cat << EOF >> "${PRESTO_CONFIG_FILE}" +log "Set up SSL in ${TRINO_CONFIG_FILE}" +cat << EOF >> "${TRINO_CONFIG_FILE}" http-server.https.enabled=true http-server.https.port=7778 http-server.https.keystore.path=${JKS_KEYSTORE_FILE} @@ -86,9 +88,18 @@ http-server.https.keystore.key=${JKS_KEYSTORE_PASS} node.internal-address-source=FQDN EOF +log "Set up memory limits in ${TRINO_CONFIG_FILE}" +cat << EOF >> "${TRINO_CONFIG_FILE}" +memory.heap-headroom-per-node=128MB +query.max-memory-per-node=512MB +query.max-total-memory-per-node=512MB +EOF + +sed -i "s/Xmx.*$/Xmx640M/" "${JVM_CONFIG_FILE}" + if [[ -n "${KRB5_CONFIG=}" ]]; then - log "Set up Kerberos in ${PRESTO_CONFIG_FILE}" - cat << EOF >> "${PRESTO_CONFIG_FILE}" + log "Set up Kerberos in ${TRINO_CONFIG_FILE}" + cat << EOF >> "${TRINO_CONFIG_FILE}" http-server.https.enabled=true http-server.https.port=7778 http-server.https.keystore.path=${JKS_KEYSTORE_FILE} @@ -103,16 +114,18 @@ EOF EOF fi -log "Waiting for keytab:${KRB5_KTNAME}" -check_service "Keytab" "test -f ${KRB5_KTNAME}" 30 +if [[ -n "${KRB5_CONFIG=}" ]]; then + log "Waiting for keytab:${KRB5_KTNAME}" + check_service "Keytab" "test -f ${KRB5_KTNAME}" 30 +fi -touch /tmp/presto-initiaalized +touch /tmp/trino-initialized echo "Config: ${JVM_CONFIG_FILE}" cat "${JVM_CONFIG_FILE}" -echo "Config: ${PRESTO_CONFIG_FILE}" -cat "${PRESTO_CONFIG_FILE}" +echo "Config: ${TRINO_CONFIG_FILE}" +cat "${TRINO_CONFIG_FILE}" log "Executing cmd: ${*}" exec /bin/sh -c "${@}" diff --git a/scripts/ci/libraries/_initialization.sh b/scripts/ci/libraries/_initialization.sh index c6e2b41bbcff7..bbdf116269137 100644 --- a/scripts/ci/libraries/_initialization.sh +++ b/scripts/ci/libraries/_initialization.sh @@ -179,7 +179,7 @@ function initialization::initialize_dockerhub_variables() { # Determine available integrations function initialization::initialize_available_integrations() { - export AVAILABLE_INTEGRATIONS="cassandra kerberos mongo openldap pinot presto rabbitmq redis statsd" + export AVAILABLE_INTEGRATIONS="cassandra kerberos mongo openldap pinot rabbitmq redis statsd trino" } # Needs to be declared outside of function for MacOS diff --git a/scripts/in_container/check_environment.sh b/scripts/in_container/check_environment.sh index 801477eede9c7..22c6fe58d2092 100755 --- a/scripts/in_container/check_environment.sh +++ b/scripts/in_container/check_environment.sh @@ -160,17 +160,17 @@ check_integration "MongoDB" "mongo" "run_nc mongo 27017" 50 check_integration "Redis" "redis" "run_nc redis 6379" 50 check_integration "Cassandra" "cassandra" "run_nc cassandra 9042" 50 
check_integration "OpenLDAP" "openldap" "run_nc openldap 389" 50 -check_integration "Presto (HTTP)" "presto" "run_nc presto 8080" 50 -check_integration "Presto (HTTPS)" "presto" "run_nc presto 7778" 50 -check_integration "Presto (API)" "presto" \ - "curl --max-time 1 http://presto:8080/v1/info/ | grep '\"starting\":false'" 50 +check_integration "Trino (HTTP)" "trino" "run_nc trino 8080" 50 +check_integration "Trino (HTTPS)" "trino" "run_nc trino 7778" 50 +check_integration "Trino (API)" "trino" \ + "curl --max-time 1 http://trino:8080/v1/info/ | grep '\"starting\":false'" 50 check_integration "Pinot (HTTP)" "pinot" "run_nc pinot 9000" 50 CMD="curl --max-time 1 -X GET 'http://pinot:9000/health' -H 'accept: text/plain' | grep OK" -check_integration "Presto (Controller API)" "pinot" "${CMD}" 50 +check_integration "Pinot (Controller API)" "pinot" "${CMD}" 50 CMD="curl --max-time 1 -X GET 'http://pinot:9000/pinot-controller/admin' -H 'accept: text/plain' | grep GOOD" -check_integration "Presto (Controller API)" "pinot" "${CMD}" 50 +check_integration "Pinot (Controller API)" "pinot" "${CMD}" 50 CMD="curl --max-time 1 -X GET 'http://pinot:8000/health' -H 'accept: text/plain' | grep OK" -check_integration "Presto (Broker API)" "pinot" "${CMD}" 50 +check_integration "Pinot (Broker API)" "pinot" "${CMD}" 50 check_integration "RabbitMQ" "rabbitmq" "run_nc rabbitmq 5672" 50 echo "-----------------------------------------------------------------------------------------------" diff --git a/scripts/in_container/run_install_and_test_provider_packages.sh b/scripts/in_container/run_install_and_test_provider_packages.sh index 4eeb83c6efbf6..f6d31b6cbb4de 100755 --- a/scripts/in_container/run_install_and_test_provider_packages.sh +++ b/scripts/in_container/run_install_and_test_provider_packages.sh @@ -95,7 +95,7 @@ function discover_all_provider_packages() { # Columns is to force it wider, so it doesn't wrap at 80 characters COLUMNS=180 airflow providers list - local expected_number_of_providers=65 + local expected_number_of_providers=66 local actual_number_of_providers actual_providers=$(airflow providers list --output yaml | grep package_name) actual_number_of_providers=$(wc -l <<<"$actual_providers") @@ -118,7 +118,7 @@ function discover_all_hooks() { group_start "Listing available hooks via 'airflow providers hooks'" COLUMNS=180 airflow providers hooks - local expected_number_of_hooks=62 + local expected_number_of_hooks=63 local actual_number_of_hooks actual_number_of_hooks=$(airflow providers hooks --output table | grep -c "| apache" | xargs) if [[ ${actual_number_of_hooks} != "${expected_number_of_hooks}" ]]; then diff --git a/setup.py b/setup.py index e1f1ebb826160..9441b6bdd96ff 100644 --- a/setup.py +++ b/setup.py @@ -454,6 +454,7 @@ def get_sphinx_theme_version() -> str: telegram = [ 'python-telegram-bot==13.0', ] +trino = ['trino'] vertica = [ 'vertica-python>=0.5.1', ] @@ -584,6 +585,7 @@ def get_sphinx_theme_version() -> str: 'ssh': ssh, 'tableau': tableau, 'telegram': telegram, + 'trino': trino, 'vertica': vertica, 'yandex': yandex, 'zendesk': zendesk, @@ -718,6 +720,7 @@ def add_extras_for_all_deprecated_aliases() -> None: 'neo4j', 'postgres', 'presto', + 'trino', 'vertica', ] @@ -933,7 +936,9 @@ def add_all_provider_packages() -> None: add_provider_packages_to_extra_requirements("devel_ci", ALL_PROVIDERS) add_provider_packages_to_extra_requirements("devel_all", ALL_PROVIDERS) add_provider_packages_to_extra_requirements("all_dbs", ALL_DB_PROVIDERS) - 
add_provider_packages_to_extra_requirements("devel_hadoop", ["apache.hdfs", "apache.hive", "presto"]) + add_provider_packages_to_extra_requirements( + "devel_hadoop", ["apache.hdfs", "apache.hive", "presto", "trino"] + ) class Develop(develop_orig): diff --git a/tests/cli/commands/test_connection_command.py b/tests/cli/commands/test_connection_command.py index cf27941776889..136811dbe3e44 100644 --- a/tests/cli/commands/test_connection_command.py +++ b/tests/cli/commands/test_connection_command.py @@ -101,6 +101,10 @@ class TestCliListConnections(unittest.TestCase): 'sqlite_default', 'sqlite', ), + ( + 'trino_default', + 'trino', + ), ( 'vertica_default', 'vertica', diff --git a/tests/conftest.py b/tests/conftest.py index d828642660174..bbf617d65ed1a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -128,7 +128,7 @@ def pytest_addoption(parser): action="append", metavar="INTEGRATIONS", help="only run tests matching integration specified: " - "[cassandra,kerberos,mongo,openldap,presto,rabbitmq,redis]. ", + "[cassandra,kerberos,mongo,openldap,rabbitmq,redis,statsd,trino]. ", ) group.addoption( "--backend", diff --git a/tests/core/test_providers_manager.py b/tests/core/test_providers_manager.py index 00880fd2e4c21..f38c8f663cdb3 100644 --- a/tests/core/test_providers_manager.py +++ b/tests/core/test_providers_manager.py @@ -83,6 +83,7 @@ 'apache-airflow-providers-ssh', 'apache-airflow-providers-tableau', 'apache-airflow-providers-telegram', + 'apache-airflow-providers-trino', 'apache-airflow-providers-vertica', 'apache-airflow-providers-yandex', 'apache-airflow-providers-zendesk', @@ -147,6 +148,7 @@ 'sqoop', 'ssh', 'tableau', + 'trino', 'vault', 'vertica', 'wasb', diff --git a/tests/providers/google/cloud/transfers/test_trino_to_gcs.py b/tests/providers/google/cloud/transfers/test_trino_to_gcs.py new file mode 100644 index 0000000000000..7cb6539a3d846 --- /dev/null +++ b/tests/providers/google/cloud/transfers/test_trino_to_gcs.py @@ -0,0 +1,331 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import unittest +from unittest.mock import patch + +import pytest + +from airflow.providers.google.cloud.transfers.trino_to_gcs import TrinoToGCSOperator + +TASK_ID = "test-trino-to-gcs" +TRINO_CONN_ID = "my-trino-conn" +GCP_CONN_ID = "my-gcp-conn" +IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"] +SQL = "SELECT * FROM memory.default.test_multiple_types" +BUCKET = "gs://test" +FILENAME = "test_{}.ndjson" + +NDJSON_LINES = [ + b'{"some_num": 42, "some_str": "mock_row_content_1"}\n', + b'{"some_num": 43, "some_str": "mock_row_content_2"}\n', + b'{"some_num": 44, "some_str": "mock_row_content_3"}\n', +] +CSV_LINES = [ + b"some_num,some_str\r\n", + b"42,mock_row_content_1\r\n", + b"43,mock_row_content_2\r\n", + b"44,mock_row_content_3\r\n", +] +SCHEMA_FILENAME = "schema_test.json" +SCHEMA_JSON = b'[{"name": "some_num", "type": "INT64"}, {"name": "some_str", "type": "STRING"}]' + + +@pytest.mark.integration("trino") +class TestTrinoToGCSOperator(unittest.TestCase): + def test_init(self): + """Test TrinoToGCSOperator instance is properly initialized.""" + op = TrinoToGCSOperator( + task_id=TASK_ID, + sql=SQL, + bucket=BUCKET, + filename=FILENAME, + impersonation_chain=IMPERSONATION_CHAIN, + ) + assert op.task_id == TASK_ID + assert op.sql == SQL + assert op.bucket == BUCKET + assert op.filename == FILENAME + assert op.impersonation_chain == IMPERSONATION_CHAIN + + @patch("airflow.providers.google.cloud.transfers.trino_to_gcs.TrinoHook") + @patch("airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook") + def test_save_as_json(self, mock_gcs_hook, mock_trino_hook): + def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): + assert BUCKET == bucket + assert FILENAME.format(0) == obj + assert "application/json" == mime_type + assert not gzip + with open(tmp_filename, "rb") as file: + assert b"".join(NDJSON_LINES) == file.read() + + mock_gcs_hook.return_value.upload.side_effect = _assert_upload + + mock_cursor = mock_trino_hook.return_value.get_conn.return_value.cursor + + mock_cursor.return_value.description = [ + ("some_num", "INTEGER", None, None, None, None, None), + ("some_str", "VARCHAR", None, None, None, None, None), + ] + + mock_cursor.return_value.fetchone.side_effect = [ + [42, "mock_row_content_1"], + [43, "mock_row_content_2"], + [44, "mock_row_content_3"], + None, + ] + + op = TrinoToGCSOperator( + task_id=TASK_ID, + sql=SQL, + bucket=BUCKET, + filename=FILENAME, + trino_conn_id=TRINO_CONN_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + op.execute(None) + + mock_trino_hook.assert_called_once_with(trino_conn_id=TRINO_CONN_ID) + mock_gcs_hook.assert_called_once_with( + delegate_to=None, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + mock_gcs_hook.return_value.upload.assert_called() + + @patch("airflow.providers.google.cloud.transfers.trino_to_gcs.TrinoHook") + @patch("airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook") + def test_save_as_json_with_file_splitting(self, mock_gcs_hook, mock_trino_hook): + """Test that ndjson is split by approx_max_file_size_bytes param.""" + + expected_upload = { + FILENAME.format(0): b"".join(NDJSON_LINES[:2]), + FILENAME.format(1): NDJSON_LINES[2], + } + + def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): + assert BUCKET == bucket + assert "application/json" == mime_type + assert not gzip + with open(tmp_filename, "rb") as file: + assert expected_upload[obj] == file.read() + + mock_gcs_hook.return_value.upload.side_effect = _assert_upload + 
+ mock_cursor = mock_trino_hook.return_value.get_conn.return_value.cursor + + mock_cursor.return_value.description = [ + ("some_num", "INTEGER", None, None, None, None, None), + ("some_str", "VARCHAR(20)", None, None, None, None, None), + ] + + mock_cursor.return_value.fetchone.side_effect = [ + [42, "mock_row_content_1"], + [43, "mock_row_content_2"], + [44, "mock_row_content_3"], + None, + ] + + op = TrinoToGCSOperator( + task_id=TASK_ID, + sql=SQL, + bucket=BUCKET, + filename=FILENAME, + approx_max_file_size_bytes=len(expected_upload[FILENAME.format(0)]), + ) + + op.execute(None) + + mock_gcs_hook.return_value.upload.assert_called() + + @patch("airflow.providers.google.cloud.transfers.trino_to_gcs.TrinoHook") + @patch("airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook") + def test_save_as_json_with_schema_file(self, mock_gcs_hook, mock_trino_hook): + """Test writing schema files.""" + + def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): # pylint: disable=unused-argument + if obj == SCHEMA_FILENAME: + with open(tmp_filename, "rb") as file: + assert SCHEMA_JSON == file.read() + + mock_gcs_hook.return_value.upload.side_effect = _assert_upload + + mock_cursor = mock_trino_hook.return_value.get_conn.return_value.cursor + + mock_cursor.return_value.description = [ + ("some_num", "INTEGER", None, None, None, None, None), + ("some_str", "VARCHAR", None, None, None, None, None), + ] + + mock_cursor.return_value.fetchone.side_effect = [ + [42, "mock_row_content_1"], + [43, "mock_row_content_2"], + [44, "mock_row_content_3"], + None, + ] + + op = TrinoToGCSOperator( + task_id=TASK_ID, + sql=SQL, + bucket=BUCKET, + filename=FILENAME, + schema_filename=SCHEMA_FILENAME, + export_format="csv", + trino_conn_id=TRINO_CONN_ID, + gcp_conn_id=GCP_CONN_ID, + ) + op.execute(None) + + # once for the file and once for the schema + assert 2 == mock_gcs_hook.return_value.upload.call_count + + @patch("airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook") + @patch("airflow.providers.google.cloud.transfers.trino_to_gcs.TrinoHook") + def test_save_as_csv(self, mock_trino_hook, mock_gcs_hook): + def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): + assert BUCKET == bucket + assert FILENAME.format(0) == obj + assert "text/csv" == mime_type + assert not gzip + with open(tmp_filename, "rb") as file: + assert b"".join(CSV_LINES) == file.read() + + mock_gcs_hook.return_value.upload.side_effect = _assert_upload + + mock_cursor = mock_trino_hook.return_value.get_conn.return_value.cursor + + mock_cursor.return_value.description = [ + ("some_num", "INTEGER", None, None, None, None, None), + ("some_str", "VARCHAR", None, None, None, None, None), + ] + + mock_cursor.return_value.fetchone.side_effect = [ + [42, "mock_row_content_1"], + [43, "mock_row_content_2"], + [44, "mock_row_content_3"], + None, + ] + + op = TrinoToGCSOperator( + task_id=TASK_ID, + sql=SQL, + bucket=BUCKET, + filename=FILENAME, + export_format="csv", + trino_conn_id=TRINO_CONN_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + op.execute(None) + + mock_gcs_hook.return_value.upload.assert_called() + + mock_trino_hook.assert_called_once_with(trino_conn_id=TRINO_CONN_ID) + mock_gcs_hook.assert_called_once_with( + delegate_to=None, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + @patch("airflow.providers.google.cloud.transfers.trino_to_gcs.TrinoHook") + @patch("airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook") + def 
test_save_as_csv_with_file_splitting(self, mock_gcs_hook, mock_trino_hook): + """Test that csv is split by approx_max_file_size_bytes param.""" + + expected_upload = { + FILENAME.format(0): b"".join(CSV_LINES[:3]), + FILENAME.format(1): b"".join([CSV_LINES[0], CSV_LINES[3]]), + } + + def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): + assert BUCKET == bucket + assert "text/csv" == mime_type + assert not gzip + with open(tmp_filename, "rb") as file: + assert expected_upload[obj] == file.read() + + mock_gcs_hook.return_value.upload.side_effect = _assert_upload + + mock_cursor = mock_trino_hook.return_value.get_conn.return_value.cursor + + mock_cursor.return_value.description = [ + ("some_num", "INTEGER", None, None, None, None, None), + ("some_str", "VARCHAR(20)", None, None, None, None, None), + ] + + mock_cursor.return_value.fetchone.side_effect = [ + [42, "mock_row_content_1"], + [43, "mock_row_content_2"], + [44, "mock_row_content_3"], + None, + ] + + op = TrinoToGCSOperator( + task_id=TASK_ID, + sql=SQL, + bucket=BUCKET, + filename=FILENAME, + approx_max_file_size_bytes=len(expected_upload[FILENAME.format(0)]), + export_format="csv", + ) + + op.execute(None) + + mock_gcs_hook.return_value.upload.assert_called() + + @patch("airflow.providers.google.cloud.transfers.trino_to_gcs.TrinoHook") + @patch("airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook") + def test_save_as_csv_with_schema_file(self, mock_gcs_hook, mock_trino_hook): + """Test writing schema files.""" + + def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): # pylint: disable=unused-argument + if obj == SCHEMA_FILENAME: + with open(tmp_filename, "rb") as file: + assert SCHEMA_JSON == file.read() + + mock_gcs_hook.return_value.upload.side_effect = _assert_upload + + mock_cursor = mock_trino_hook.return_value.get_conn.return_value.cursor + + mock_cursor.return_value.description = [ + ("some_num", "INTEGER", None, None, None, None, None), + ("some_str", "VARCHAR", None, None, None, None, None), + ] + + mock_cursor.return_value.fetchone.side_effect = [ + [42, "mock_row_content_1"], + [43, "mock_row_content_2"], + [44, "mock_row_content_3"], + None, + ] + + op = TrinoToGCSOperator( + task_id=TASK_ID, + sql=SQL, + bucket=BUCKET, + filename=FILENAME, + schema_filename=SCHEMA_FILENAME, + export_format="csv", + ) + op.execute(None) + + # once for the file and once for the schema + assert 2 == mock_gcs_hook.return_value.upload.call_count diff --git a/tests/providers/google/cloud/transfers/test_trino_to_gcs_system.py b/tests/providers/google/cloud/transfers/test_trino_to_gcs_system.py new file mode 100644 index 0000000000000..00d5716556183 --- /dev/null +++ b/tests/providers/google/cloud/transfers/test_trino_to_gcs_system.py @@ -0,0 +1,169 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +import os +from contextlib import closing, suppress + +import pytest + +from airflow.models import Connection +from airflow.providers.trino.hooks.trino import TrinoHook +from tests.providers.google.cloud.utils.gcp_authenticator import GCP_BIGQUERY_KEY, GCP_GCS_KEY +from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context + +try: + from airflow.utils.session import create_session +except ImportError: + # This is a hack to import create_session from old destination and + # fool the pre-commit check that looks for old imports... + # TODO remove this once we don't need to test this on 1.10 + import importlib + + db_module = importlib.import_module("airflow.utils.db") + create_session = getattr(db_module, "create_session") + + +GCS_BUCKET = os.environ.get("GCP_TRINO_TO_GCS_BUCKET_NAME", "test-trino-to-gcs-bucket") +DATASET_NAME = os.environ.get("GCP_TRINO_TO_GCS_DATASET_NAME", "test_trino_to_gcs_dataset") + +CREATE_QUERY = """ +CREATE TABLE memory.default.test_multiple_types ( + -- Boolean + z_boolean BOOLEAN, + -- Integers + z_tinyint TINYINT, + z_smallint SMALLINT, + z_integer INTEGER, + z_bigint BIGINT, + -- Floating-Point + z_real REAL, + z_double DOUBLE, + -- Fixed-Point + z_decimal DECIMAL(10,2), + -- String + z_varchar VARCHAR(20), + z_char CHAR(20), + z_varbinary VARBINARY, + z_json JSON, + -- Date and Time + z_date DATE, + z_time TIME, + z_time_with_time_zone TIME WITH TIME ZONE, + z_timestamp TIMESTAMP, + z_timestamp_with_time_zone TIMESTAMP WITH TIME ZONE, + -- Network Address + z_ipaddress_v4 IPADDRESS, + z_ipaddress_v6 IPADDRESS, + -- UUID + z_uuid UUID +) +""" + +LOAD_QUERY = """ +INSERT INTO memory.default.test_multiple_types VALUES( + -- Boolean + true, -- z_boolean BOOLEAN, + -- Integers + CAST(POW(2, 7 ) - 42 AS TINYINT), -- z_tinyint TINYINT, + CAST(POW(2, 15) - 42 AS SMALLINT), -- z_smallint SMALLINT, + CAST(POW(2, 31) - 42 AS INTEGER), -- z_integer INTEGER, + CAST(POW(2, 32) - 42 AS BIGINT) * 2, -- z_bigint BIGINT, + -- Floating-Point + REAL '42', -- z_real REAL, + DOUBLE '1.03e42', -- z_double DOUBLE, + -- Floating-Point + DECIMAL '1.1', -- z_decimal DECIMAL(10, 2), + -- String + U&'Hello winter \2603 !', -- z_vaarchar VARCHAR(20), + 'cat', -- z_char CHAR(20), + X'65683F', -- z_varbinary VARBINARY, + CAST('["A", 1, true]' AS JSON), -- z_json JSON, + -- Date and Time + DATE '2001-08-22', -- z_date DATE, + TIME '01:02:03.456', -- z_time TIME, + TIME '01:02:03.456 America/Los_Angeles', -- z_time_with_time_zone TIME WITH TIME ZONE, + TIMESTAMP '2001-08-22 03:04:05.321', -- z_timestamp TIMESTAMP, + TIMESTAMP '2001-08-22 03:04:05.321 America/Los_Angeles', -- z_timestamp_with_time_zone TIMESTAMP WITH TIME + -- ZONE, + -- Network Address + IPADDRESS '10.0.0.1', -- z_ipaddress_v4 IPADDRESS, + IPADDRESS '2001:db8::1', -- z_ipaddress_v6 IPADDRESS, + -- UUID + UUID '12151fd2-7586-11e9-8f9e-2a86e4085a59' -- z_uuid UUID +) +""" +DELETE_QUERY = "DROP TABLE memory.default.test_multiple_types" + + +@pytest.mark.integration("trino") +class TrinoToGCSSystemTest(GoogleSystemTest): + @staticmethod + def init_connection(): + with create_session() as session: + session.query(Connection).filter(Connection.conn_id == "trino_default").delete() + session.merge( + Connection( + conn_id="trino_default", conn_type="conn_type", host="trino", port=8080, login="airflow" + ) + ) + + @staticmethod + def init_db(): + hook = TrinoHook() + with hook.get_conn() 
as conn: + with closing(conn.cursor()) as cur: + cur.execute(CREATE_QUERY) + # Trino does not execute queries until the result is fetched. :-( + cur.fetchone() + cur.execute(LOAD_QUERY) + cur.fetchone() + + @staticmethod + def drop_db(): + hook = TrinoHook() + with hook.get_conn() as conn: + with closing(conn.cursor()) as cur: + cur.execute(DELETE_QUERY) + # Trino does not execute queries until the result is fetched. :-( + cur.fetchone() + + @provide_gcp_context(GCP_GCS_KEY) + def setUp(self): + super().setUp() + self.init_connection() + self.create_gcs_bucket(GCS_BUCKET) + with suppress(Exception): + self.drop_db() + self.init_db() + self.execute_with_ctx( + ["bq", "rm", "--recursive", "--force", f"{self._project_id()}:{DATASET_NAME}"], + key=GCP_BIGQUERY_KEY, + ) + + @provide_gcp_context(GCP_BIGQUERY_KEY) + def test_run_example_dag(self): + self.run_dag("example_trino_to_gcs", CLOUD_DAG_FOLDER) + + @provide_gcp_context(GCP_GCS_KEY) + def tearDown(self): + self.delete_gcs_bucket(GCS_BUCKET) + self.drop_db() + self.execute_with_ctx( + ["bq", "rm", "--recursive", "--force", f"{self._project_id()}:{DATASET_NAME}"], + key=GCP_BIGQUERY_KEY, + ) + super().tearDown() diff --git a/tests/providers/mysql/transfers/test_trino_to_mysql.py b/tests/providers/mysql/transfers/test_trino_to_mysql.py new file mode 100644 index 0000000000000..2e23169cc5f29 --- /dev/null +++ b/tests/providers/mysql/transfers/test_trino_to_mysql.py @@ -0,0 +1,73 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import os +import unittest +from unittest.mock import patch + +from airflow.providers.mysql.transfers.trino_to_mysql import TrinoToMySqlOperator +from tests.providers.apache.hive import DEFAULT_DATE, TestHiveEnvironment + + +class TestTrinoToMySqlTransfer(TestHiveEnvironment): + def setUp(self): + self.kwargs = dict( + sql='sql', + mysql_table='mysql_table', + task_id='test_trino_to_mysql_transfer', + ) + super().setUp() + + @patch('airflow.providers.mysql.transfers.trino_to_mysql.MySqlHook') + @patch('airflow.providers.mysql.transfers.trino_to_mysql.TrinoHook') + def test_execute(self, mock_trino_hook, mock_mysql_hook): + TrinoToMySqlOperator(**self.kwargs).execute(context={}) + + mock_trino_hook.return_value.get_records.assert_called_once_with(self.kwargs['sql']) + mock_mysql_hook.return_value.insert_rows.assert_called_once_with( + table=self.kwargs['mysql_table'], rows=mock_trino_hook.return_value.get_records.return_value + ) + + @patch('airflow.providers.mysql.transfers.trino_to_mysql.MySqlHook') + @patch('airflow.providers.mysql.transfers.trino_to_mysql.TrinoHook') + def test_execute_with_mysql_preoperator(self, mock_trino_hook, mock_mysql_hook): + self.kwargs.update(dict(mysql_preoperator='mysql_preoperator')) + + TrinoToMySqlOperator(**self.kwargs).execute(context={}) + + mock_trino_hook.return_value.get_records.assert_called_once_with(self.kwargs['sql']) + mock_mysql_hook.return_value.run.assert_called_once_with(self.kwargs['mysql_preoperator']) + mock_mysql_hook.return_value.insert_rows.assert_called_once_with( + table=self.kwargs['mysql_table'], rows=mock_trino_hook.return_value.get_records.return_value + ) + + @unittest.skipIf( + 'AIRFLOW_RUNALL_TESTS' not in os.environ, "Skipped because AIRFLOW_RUNALL_TESTS is not set" + ) + def test_trino_to_mysql(self): + op = TrinoToMySqlOperator( + task_id='trino_to_mysql_check', + sql=""" + SELECT name, count(*) as ccount + FROM airflow.static_babynames + GROUP BY name + """, + mysql_table='test_static_babynames', + mysql_preoperator='TRUNCATE TABLE test_static_babynames;', + dag=self.dag, + ) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) diff --git a/tests/providers/presto/hooks/test_presto.py b/tests/providers/presto/hooks/test_presto.py index f9e85875fd8de..e6ebb737c2c09 100644 --- a/tests/providers/presto/hooks/test_presto.py +++ b/tests/providers/presto/hooks/test_presto.py @@ -206,28 +206,3 @@ def test_get_pandas_df(self): assert result_sets[1][0] == df.values.tolist()[1][0] self.cur.execute.assert_called_once_with(statement, None) - - -class TestPrestoHookIntegration(unittest.TestCase): - @pytest.mark.integration("presto") - @mock.patch.dict('os.environ', AIRFLOW_CONN_PRESTO_DEFAULT="presto://airflow@presto:8080/") - def test_should_record_records(self): - hook = PrestoHook() - sql = "SELECT name FROM tpch.sf1.customer ORDER BY custkey ASC LIMIT 3" - records = hook.get_records(sql) - assert [['Customer#000000001'], ['Customer#000000002'], ['Customer#000000003']] == records - - @pytest.mark.integration("presto") - @pytest.mark.integration("kerberos") - def test_should_record_records_with_kerberos_auth(self): - conn_url = ( - 'presto://airflow@presto:7778/?' 
- 'auth=kerberos&kerberos__service_name=HTTP&' - 'verify=False&' - 'protocol=https' - ) - with mock.patch.dict('os.environ', AIRFLOW_CONN_PRESTO_DEFAULT=conn_url): - hook = PrestoHook() - sql = "SELECT name FROM tpch.sf1.customer ORDER BY custkey ASC LIMIT 3" - records = hook.get_records(sql) - assert [['Customer#000000001'], ['Customer#000000002'], ['Customer#000000003']] == records diff --git a/tests/providers/trino/__init__.py b/tests/providers/trino/__init__.py new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/tests/providers/trino/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/trino/hooks/__init__.py b/tests/providers/trino/hooks/__init__.py new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/tests/providers/trino/hooks/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/trino/hooks/test_trino.py b/tests/providers/trino/hooks/test_trino.py new file mode 100644 index 0000000000000..e649d2bece789 --- /dev/null +++ b/tests/providers/trino/hooks/test_trino.py @@ -0,0 +1,233 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +import json +import re +import unittest +from unittest import mock +from unittest.mock import patch + +import pytest +from parameterized import parameterized +from trino.transaction import IsolationLevel + +from airflow import AirflowException +from airflow.models import Connection +from airflow.providers.trino.hooks.trino import TrinoHook + + +class TestTrinoHookConn(unittest.TestCase): + @patch('airflow.providers.trino.hooks.trino.trino.auth.BasicAuthentication') + @patch('airflow.providers.trino.hooks.trino.trino.dbapi.connect') + @patch('airflow.providers.trino.hooks.trino.TrinoHook.get_connection') + def test_get_conn_basic_auth(self, mock_get_connection, mock_connect, mock_basic_auth): + mock_get_connection.return_value = Connection( + login='login', password='password', host='host', schema='hive' + ) + + conn = TrinoHook().get_conn() + mock_connect.assert_called_once_with( + catalog='hive', + host='host', + port=None, + http_scheme='http', + schema='hive', + source='airflow', + user='login', + isolation_level=0, + auth=mock_basic_auth.return_value, + ) + mock_basic_auth.assert_called_once_with('login', 'password') + assert mock_connect.return_value == conn + + @patch('airflow.providers.trino.hooks.trino.TrinoHook.get_connection') + def test_get_conn_invalid_auth(self, mock_get_connection): + mock_get_connection.return_value = Connection( + login='login', + password='password', + host='host', + schema='hive', + extra=json.dumps({'auth': 'kerberos'}), + ) + with pytest.raises( + AirflowException, match=re.escape("Kerberos authorization doesn't support password.") + ): + TrinoHook().get_conn() + + @patch('airflow.providers.trino.hooks.trino.trino.auth.KerberosAuthentication') + @patch('airflow.providers.trino.hooks.trino.trino.dbapi.connect') + @patch('airflow.providers.trino.hooks.trino.TrinoHook.get_connection') + def test_get_conn_kerberos_auth(self, mock_get_connection, mock_connect, mock_auth): + mock_get_connection.return_value = Connection( + login='login', + host='host', + schema='hive', + extra=json.dumps( + { + 'auth': 'kerberos', + 'kerberos__config': 'TEST_KERBEROS_CONFIG', + 'kerberos__service_name': 'TEST_SERVICE_NAME', + 'kerberos__mutual_authentication': 'TEST_MUTUAL_AUTHENTICATION', + 'kerberos__force_preemptive': True, + 'kerberos__hostname_override': 'TEST_HOSTNAME_OVERRIDE', + 'kerberos__sanitize_mutual_error_response': True, + 'kerberos__principal': 'TEST_PRINCIPAL', + 'kerberos__delegate': 'TEST_DELEGATE', + 'kerberos__ca_bundle': 'TEST_CA_BUNDLE', + } + ), + ) + + conn = TrinoHook().get_conn() + mock_connect.assert_called_once_with( + catalog='hive', + host='host', + port=None, + http_scheme='http', + schema='hive', + source='airflow', + user='login', + isolation_level=0, + auth=mock_auth.return_value, + ) + mock_auth.assert_called_once_with( + ca_bundle='TEST_CA_BUNDLE', + config='TEST_KERBEROS_CONFIG', + delegate='TEST_DELEGATE', + force_preemptive=True, + hostname_override='TEST_HOSTNAME_OVERRIDE', + mutual_authentication='TEST_MUTUAL_AUTHENTICATION', + principal='TEST_PRINCIPAL', + sanitize_mutual_error_response=True, + service_name='TEST_SERVICE_NAME', + ) + assert mock_connect.return_value == conn + + @parameterized.expand( + [ + ('False', False), + ('false', False), + ('true', True), + ('true', True), + ('/tmp/cert.crt', '/tmp/cert.crt'), + ] + ) + def test_get_conn_verify(self, current_verify, expected_verify): + patcher_connect = patch('airflow.providers.trino.hooks.trino.trino.dbapi.connect') + patcher_get_connections = 
patch('airflow.providers.trino.hooks.trino.TrinoHook.get_connection') + + with patcher_connect as mock_connect, patcher_get_connections as mock_get_connection: + mock_get_connection.return_value = Connection( + login='login', host='host', schema='hive', extra=json.dumps({'verify': current_verify}) + ) + mock_verify = mock.PropertyMock() + type(mock_connect.return_value._http_session).verify = mock_verify + + conn = TrinoHook().get_conn() + mock_verify.assert_called_once_with(expected_verify) + assert mock_connect.return_value == conn + + +class TestTrinoHook(unittest.TestCase): + def setUp(self): + super().setUp() + + self.cur = mock.MagicMock() + self.conn = mock.MagicMock() + self.conn.cursor.return_value = self.cur + conn = self.conn + + class UnitTestTrinoHook(TrinoHook): + conn_name_attr = 'test_conn_id' + + def get_conn(self): + return conn + + def get_isolation_level(self): + return IsolationLevel.READ_COMMITTED + + self.db_hook = UnitTestTrinoHook() + + @patch('airflow.hooks.dbapi.DbApiHook.insert_rows') + def test_insert_rows(self, mock_insert_rows): + table = "table" + rows = [("hello",), ("world",)] + target_fields = None + commit_every = 10 + self.db_hook.insert_rows(table, rows, target_fields, commit_every) + mock_insert_rows.assert_called_once_with(table, rows, None, 10) + + def test_get_first_record(self): + statement = 'SQL' + result_sets = [('row1',), ('row2',)] + self.cur.fetchone.return_value = result_sets[0] + + assert result_sets[0] == self.db_hook.get_first(statement) + self.conn.close.assert_called_once_with() + self.cur.close.assert_called_once_with() + self.cur.execute.assert_called_once_with(statement) + + def test_get_records(self): + statement = 'SQL' + result_sets = [('row1',), ('row2',)] + self.cur.fetchall.return_value = result_sets + + assert result_sets == self.db_hook.get_records(statement) + self.conn.close.assert_called_once_with() + self.cur.close.assert_called_once_with() + self.cur.execute.assert_called_once_with(statement) + + def test_get_pandas_df(self): + statement = 'SQL' + column = 'col' + result_sets = [('row1',), ('row2',)] + self.cur.description = [(column,)] + self.cur.fetchall.return_value = result_sets + df = self.db_hook.get_pandas_df(statement) + + assert column == df.columns[0] + + assert result_sets[0][0] == df.values.tolist()[0][0] + assert result_sets[1][0] == df.values.tolist()[1][0] + + self.cur.execute.assert_called_once_with(statement, None) + + +class TestTrinoHookIntegration(unittest.TestCase): + @pytest.mark.integration("trino") + @mock.patch.dict('os.environ', AIRFLOW_CONN_TRINO_DEFAULT="trino://airflow@trino:8080/") + def test_should_record_records(self): + hook = TrinoHook() + sql = "SELECT name FROM tpch.sf1.customer ORDER BY custkey ASC LIMIT 3" + records = hook.get_records(sql) + assert [['Customer#000000001'], ['Customer#000000002'], ['Customer#000000003']] == records + + @pytest.mark.integration("trino") + @pytest.mark.integration("kerberos") + def test_should_record_records_with_kerberos_auth(self): + conn_url = ( + 'trino://airflow@trino.example.com:7778/?' + 'auth=kerberos&kerberos__service_name=HTTP&' + 'verify=False&' + 'protocol=https' + ) + with mock.patch.dict('os.environ', AIRFLOW_CONN_TRINO_DEFAULT=conn_url): + hook = TrinoHook() + sql = "SELECT name FROM tpch.sf1.customer ORDER BY custkey ASC LIMIT 3" + records = hook.get_records(sql) + assert [['Customer#000000001'], ['Customer#000000002'], ['Customer#000000003']] == records