diff --git a/BREEZE.rst b/BREEZE.rst
index 7cd26a03c2405..ce5f6e5d4cb8f 100644
--- a/BREEZE.rst
+++ b/BREEZE.rst
@@ -2444,7 +2444,7 @@ This is the current syntax for `./breeze <./breeze>`_:
start all integrations. Selected integrations are not saved for future execution.
One of:
- cassandra kerberos mongo openldap pinot presto rabbitmq redis statsd all
+ cassandra kerberos mongo openldap pinot rabbitmq redis statsd trino all
--init-script INIT_SCRIPT_FILE
Initialization script name - Sourced from files/airflow-breeze-config. Default value
diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index 3b3a77190fae3..1d5f151f44b90 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -594,7 +594,7 @@ github_enterprise, google, google_auth, grpc, hashicorp, hdfs, hive, http, imap,
jira, kerberos, kubernetes, ldap, microsoft.azure, microsoft.mssql, microsoft.winrm, mongo, mssql,
mysql, neo4j, odbc, openfaas, opsgenie, oracle, pagerduty, papermill, password, pinot, plexus,
postgres, presto, qds, qubole, rabbitmq, redis, s3, salesforce, samba, segment, sendgrid, sentry,
-sftp, singularity, slack, snowflake, spark, sqlite, ssh, statsd, tableau, telegram, vertica,
+sftp, singularity, slack, snowflake, spark, sqlite, ssh, statsd, tableau, telegram, trino, vertica,
virtualenv, webhdfs, winrm, yandex, zendesk
.. END EXTRAS HERE
@@ -661,11 +661,11 @@ apache.hive amazon,microsoft.mssql,mysql,presto,samba,vertica
apache.livy http
dingding http
discord http
-google amazon,apache.beam,apache.cassandra,cncf.kubernetes,facebook,microsoft.azure,microsoft.mssql,mysql,oracle,postgres,presto,salesforce,sftp,ssh
+google amazon,apache.beam,apache.cassandra,cncf.kubernetes,facebook,microsoft.azure,microsoft.mssql,mysql,oracle,postgres,presto,salesforce,sftp,ssh,trino
hashicorp google
microsoft.azure google,oracle
microsoft.mssql odbc
-mysql amazon,presto,vertica
+mysql amazon,presto,trino,vertica
opsgenie http
postgres amazon
salesforce tableau
@@ -756,7 +756,7 @@ providers.
not only "green path"
* Integration tests where 'local' integration with a component is possible (for example tests with
- MySQL/Postgres DB/Presto/Kerberos all have integration tests which run with real, dockerised components
+ MySQL/Postgres DB/Trino/Kerberos all have integration tests which run with real, dockerized components
* System Tests which provide end-to-end testing, usually testing together several operators, sensors,
transfers connecting to a real external system
diff --git a/IMAGES.rst b/IMAGES.rst
index b5a309627916f..b206a4816a827 100644
--- a/IMAGES.rst
+++ b/IMAGES.rst
@@ -116,7 +116,7 @@ parameter to Breeze:
.. code-block:: bash
- ./breeze build-image --python 3.7 --additional-extras=presto \
+ ./breeze build-image --python 3.7 --additional-extras=trino \
--production-image --install-airflow-version=2.0.0
@@ -163,7 +163,7 @@ You can also skip installing airflow and install it from locally provided files
.. code-block:: bash
- ./breeze build-image --python 3.7 --additional-extras=presto \
+ ./breeze build-image --python 3.7 --additional-extras=trino \
--production-image --disable-pypi-when-building --install-from-local-files-when-building
In this case your Airflow package and all provider packages (.whl files) should be placed in the ``docker-context-files`` folder.
diff --git a/INSTALL b/INSTALL
index eeab5583f04af..46d15f62aa87b 100644
--- a/INSTALL
+++ b/INSTALL
@@ -106,7 +106,7 @@ github_enterprise, google, google_auth, grpc, hashicorp, hdfs, hive, http, imap,
jira, kerberos, kubernetes, ldap, microsoft.azure, microsoft.mssql, microsoft.winrm, mongo, mssql,
mysql, neo4j, odbc, openfaas, opsgenie, oracle, pagerduty, papermill, password, pinot, plexus,
postgres, presto, qds, qubole, rabbitmq, redis, s3, salesforce, samba, segment, sendgrid, sentry,
-sftp, singularity, slack, snowflake, spark, sqlite, ssh, statsd, tableau, telegram, vertica,
+sftp, singularity, slack, snowflake, spark, sqlite, ssh, statsd, tableau, telegram, trino, vertica,
virtualenv, webhdfs, winrm, yandex, zendesk
# END EXTRAS HERE
diff --git a/TESTING.rst b/TESTING.rst
index e73ae464ac68f..07a3e73843271 100644
--- a/TESTING.rst
+++ b/TESTING.rst
@@ -281,12 +281,12 @@ The following integrations are available:
- Integration required for OpenLDAP hooks
* - pinot
- Integration required for Apache Pinot hooks
- * - presto
- - Integration required for Presto hooks
* - rabbitmq
- Integration required for Celery executor tests
* - redis
- Integration required for Celery executor tests
+ * - trino
+ - Integration required for Trino hooks
To start the ``mongo`` integration only, enter:
diff --git a/airflow/providers/dependencies.json b/airflow/providers/dependencies.json
index e4f100df684aa..eabb657d181d7 100644
--- a/airflow/providers/dependencies.json
+++ b/airflow/providers/dependencies.json
@@ -50,7 +50,8 @@
"presto",
"salesforce",
"sftp",
- "ssh"
+ "ssh",
+ "trino"
],
"hashicorp": [
"google"
@@ -65,6 +66,7 @@
"mysql": [
"amazon",
"presto",
+ "trino",
"vertica"
],
"opsgenie": [
diff --git a/airflow/providers/google/cloud/example_dags/example_trino_to_gcs.py b/airflow/providers/google/cloud/example_dags/example_trino_to_gcs.py
new file mode 100644
index 0000000000000..32dc8a004b79f
--- /dev/null
+++ b/airflow/providers/google/cloud/example_dags/example_trino_to_gcs.py
@@ -0,0 +1,150 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example DAG using TrinoToGCSOperator.
+"""
+import os
+import re
+
+from airflow import models
+from airflow.providers.google.cloud.operators.bigquery import (
+ BigQueryCreateEmptyDatasetOperator,
+ BigQueryCreateExternalTableOperator,
+ BigQueryDeleteDatasetOperator,
+ BigQueryExecuteQueryOperator,
+)
+from airflow.providers.google.cloud.transfers.trino_to_gcs import TrinoToGCSOperator
+from airflow.utils.dates import days_ago
+
+GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project")
+GCS_BUCKET = os.environ.get("GCP_TRINO_TO_GCS_BUCKET_NAME", "test-trino-to-gcs-bucket")
+DATASET_NAME = os.environ.get("GCP_TRINO_TO_GCS_DATASET_NAME", "test_trino_to_gcs_dataset")
+
+SOURCE_MULTIPLE_TYPES = "memory.default.test_multiple_types"
+SOURCE_CUSTOMER_TABLE = "tpch.sf1.customer"
+
+
+def safe_name(s: str) -> str:
+ """
+ Remove invalid characters for filename
+ """
+ return re.sub("[^0-9a-zA-Z_]+", "_", s)
+
+
+with models.DAG(
+ dag_id="example_trino_to_gcs",
+ schedule_interval=None, # Override to match your needs
+ start_date=days_ago(1),
+ tags=["example"],
+) as dag:
+
+ create_dataset = BigQueryCreateEmptyDatasetOperator(task_id="create-dataset", dataset_id=DATASET_NAME)
+
+ delete_dataset = BigQueryDeleteDatasetOperator(
+ task_id="delete_dataset", dataset_id=DATASET_NAME, delete_contents=True
+ )
+
+ # [START howto_operator_trino_to_gcs_basic]
+ trino_to_gcs_basic = TrinoToGCSOperator(
+ task_id="trino_to_gcs_basic",
+ sql=f"select * from {SOURCE_MULTIPLE_TYPES}",
+ bucket=GCS_BUCKET,
+ filename=f"{safe_name(SOURCE_MULTIPLE_TYPES)}.{{}}.json",
+ )
+ # [END howto_operator_trino_to_gcs_basic]
+
+ # [START howto_operator_trino_to_gcs_multiple_types]
+ trino_to_gcs_multiple_types = TrinoToGCSOperator(
+ task_id="trino_to_gcs_multiple_types",
+ sql=f"select * from {SOURCE_MULTIPLE_TYPES}",
+ bucket=GCS_BUCKET,
+ filename=f"{safe_name(SOURCE_MULTIPLE_TYPES)}.{{}}.json",
+ schema_filename=f"{safe_name(SOURCE_MULTIPLE_TYPES)}-schema.json",
+ gzip=False,
+ )
+ # [END howto_operator_trino_to_gcs_multiple_types]
+
+ # [START howto_operator_create_external_table_multiple_types]
+ create_external_table_multiple_types = BigQueryCreateExternalTableOperator(
+ task_id="create_external_table_multiple_types",
+ bucket=GCS_BUCKET,
+ source_objects=[f"{safe_name(SOURCE_MULTIPLE_TYPES)}.*.json"],
+ source_format="NEWLINE_DELIMITED_JSON",
+ destination_project_dataset_table=f"{DATASET_NAME}.{safe_name(SOURCE_MULTIPLE_TYPES)}",
+ schema_object=f"{safe_name(SOURCE_MULTIPLE_TYPES)}-schema.json",
+ )
+ # [END howto_operator_create_external_table_multiple_types]
+
+ read_data_from_gcs_multiple_types = BigQueryExecuteQueryOperator(
+ task_id="read_data_from_gcs_multiple_types",
+ sql=f"SELECT COUNT(*) FROM `{GCP_PROJECT_ID}.{DATASET_NAME}.{safe_name(SOURCE_MULTIPLE_TYPES)}`",
+ use_legacy_sql=False,
+ )
+
+ # [START howto_operator_trino_to_gcs_many_chunks]
+ trino_to_gcs_many_chunks = TrinoToGCSOperator(
+ task_id="trino_to_gcs_many_chunks",
+ sql=f"select * from {SOURCE_CUSTOMER_TABLE}",
+ bucket=GCS_BUCKET,
+ filename=f"{safe_name(SOURCE_CUSTOMER_TABLE)}.{{}}.json",
+ schema_filename=f"{safe_name(SOURCE_CUSTOMER_TABLE)}-schema.json",
+ approx_max_file_size_bytes=10_000_000,
+ gzip=False,
+ )
+ # [END howto_operator_trino_to_gcs_many_chunks]
+
+ create_external_table_many_chunks = BigQueryCreateExternalTableOperator(
+ task_id="create_external_table_many_chunks",
+ bucket=GCS_BUCKET,
+ source_objects=[f"{safe_name(SOURCE_CUSTOMER_TABLE)}.*.json"],
+ source_format="NEWLINE_DELIMITED_JSON",
+ destination_project_dataset_table=f"{DATASET_NAME}.{safe_name(SOURCE_CUSTOMER_TABLE)}",
+ schema_object=f"{safe_name(SOURCE_CUSTOMER_TABLE)}-schema.json",
+ )
+
+ # [START howto_operator_read_data_from_gcs_many_chunks]
+ read_data_from_gcs_many_chunks = BigQueryExecuteQueryOperator(
+ task_id="read_data_from_gcs_many_chunks",
+ sql=f"SELECT COUNT(*) FROM `{GCP_PROJECT_ID}.{DATASET_NAME}.{safe_name(SOURCE_CUSTOMER_TABLE)}`",
+ use_legacy_sql=False,
+ )
+ # [END howto_operator_read_data_from_gcs_many_chunks]
+
+ # [START howto_operator_trino_to_gcs_csv]
+ trino_to_gcs_csv = TrinoToGCSOperator(
+ task_id="trino_to_gcs_csv",
+ sql=f"select * from {SOURCE_MULTIPLE_TYPES}",
+ bucket=GCS_BUCKET,
+ filename=f"{safe_name(SOURCE_MULTIPLE_TYPES)}.{{}}.csv",
+ schema_filename=f"{safe_name(SOURCE_MULTIPLE_TYPES)}-schema.json",
+ export_format="csv",
+ )
+ # [END howto_operator_trino_to_gcs_csv]
+
+ create_dataset >> trino_to_gcs_basic
+ create_dataset >> trino_to_gcs_multiple_types
+ create_dataset >> trino_to_gcs_many_chunks
+ create_dataset >> trino_to_gcs_csv
+
+ trino_to_gcs_multiple_types >> create_external_table_multiple_types >> read_data_from_gcs_multiple_types
+ trino_to_gcs_many_chunks >> create_external_table_many_chunks >> read_data_from_gcs_many_chunks
+
+ trino_to_gcs_basic >> delete_dataset
+ trino_to_gcs_csv >> delete_dataset
+ read_data_from_gcs_multiple_types >> delete_dataset
+ read_data_from_gcs_many_chunks >> delete_dataset
diff --git a/airflow/providers/google/cloud/transfers/trino_to_gcs.py b/airflow/providers/google/cloud/transfers/trino_to_gcs.py
new file mode 100644
index 0000000000000..e2f2306bc80cc
--- /dev/null
+++ b/airflow/providers/google/cloud/transfers/trino_to_gcs.py
@@ -0,0 +1,210 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import Any, Dict, List, Tuple
+
+from trino.client import TrinoResult
+from trino.dbapi import Cursor as TrinoCursor
+
+from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator
+from airflow.providers.trino.hooks.trino import TrinoHook
+from airflow.utils.decorators import apply_defaults
+
+
+class _TrinoToGCSTrinoCursorAdapter:
+ """
+ An adapter that adds additional features to the Trino cursor.
+
+ The implementation of the cursor in the trino library is not sufficient,
+ so the following changes have been made:
+
+ * A peek mechanism for rows: you can look at the next row without consuming it.
+ * The ``description`` attribute is available before reading the first row, thanks to the peek mechanism.
+ * The iterator interface has been implemented.
+
+ A detailed description of the class methods is available in
+ `PEP-249 <https://www.python.org/dev/peps/pep-0249/>`__.
+ """
+
+ def __init__(self, cursor: TrinoCursor):
+ self.cursor: TrinoCursor = cursor
+ self.rows: List[Any] = []
+ self.initialized: bool = False
+
+ @property
+ def description(self) -> List[Tuple]:
+ """
+ This read-only attribute is a sequence of 7-item sequences.
+
+ Each of these sequences contains information describing one result column:
+
+ * ``name``
+ * ``type_code``
+ * ``display_size``
+ * ``internal_size``
+ * ``precision``
+ * ``scale``
+ * ``null_ok``
+
+ The first two items (``name`` and ``type_code``) are mandatory, the other
+ five are optional and are set to None if no meaningful values can be provided.
+ """
+ if not self.initialized:
+ # Peek at the first row to load the description.
+ self.peekone()
+ return self.cursor.description
+
+ @property
+ def rowcount(self) -> int:
+ """The read-only attribute specifies the number of rows"""
+ return self.cursor.rowcount
+
+ def close(self) -> None:
+ """Close the cursor now"""
+ self.cursor.close()
+
+ def execute(self, *args, **kwargs) -> TrinoResult:
+ """Prepare and execute a database operation (query or command)."""
+ self.initialized = False
+ self.rows = []
+ return self.cursor.execute(*args, **kwargs)
+
+ def executemany(self, *args, **kwargs):
+ """
+ Prepare a database operation (query or command) and then execute it against all parameter
+ sequences or mappings found in the sequence seq_of_parameters.
+ """
+ self.initialized = False
+ self.rows = []
+ return self.cursor.executemany(*args, **kwargs)
+
+ def peekone(self) -> Any:
+ """Return the next row without consuming it."""
+ self.initialized = True
+ element = self.cursor.fetchone()
+ self.rows.insert(0, element)
+ return element
+
+ def fetchone(self) -> Any:
+ """
+ Fetch the next row of a query result set, returning a single sequence, or
+ ``None`` when no more data is available.
+ """
+ if self.rows:
+ return self.rows.pop(0)
+ return self.cursor.fetchone()
+
+ def fetchmany(self, size=None) -> list:
+ """
+ Fetch the next set of rows of a query result, returning a sequence of sequences
+ (e.g. a list of tuples). An empty sequence is returned when no more rows are available.
+ """
+ if size is None:
+ size = self.cursor.arraysize
+
+ result = []
+ for _ in range(size):
+ row = self.fetchone()
+ if row is None:
+ break
+ result.append(row)
+
+ return result
+
+ def __next__(self) -> Any:
+ """
+ Return the next row from the currently executing SQL statement using the same semantics as
+ ``.fetchone()``. A ``StopIteration`` exception is raised when the result set is exhausted.
+ """
+ result = self.fetchone()
+ if result is None:
+ raise StopIteration()
+ return result
+
+ def __iter__(self) -> "_TrinoToGCSTrinoCursorAdapter":
+ """Return self to make cursors compatible to the iteration protocol"""
+ return self
+
+
+class TrinoToGCSOperator(BaseSQLToGCSOperator):
+ """Copy data from TrinoDB to Google Cloud Storage in JSON or CSV format.
+
+ :param trino_conn_id: Reference to a specific Trino hook.
+ :type trino_conn_id: str
+ """
+
+ ui_color = "#a0e08c"
+
+ type_map = {
+ "BOOLEAN": "BOOL",
+ "TINYINT": "INT64",
+ "SMALLINT": "INT64",
+ "INTEGER": "INT64",
+ "BIGINT": "INT64",
+ "REAL": "FLOAT64",
+ "DOUBLE": "FLOAT64",
+ "DECIMAL": "NUMERIC",
+ "VARCHAR": "STRING",
+ "CHAR": "STRING",
+ "VARBINARY": "BYTES",
+ "JSON": "STRING",
+ "DATE": "DATE",
+ "TIME": "TIME",
+ # BigQuery does not support TIME with time zone natively.
+ "TIME WITH TIME ZONE": "STRING",
+ "TIMESTAMP": "TIMESTAMP",
+ # BigQuery supports a narrow range of time zones during import.
+ # Use the TIMESTAMP function if you want a TIMESTAMP type.
+ "TIMESTAMP WITH TIME ZONE": "STRING",
+ "IPADDRESS": "STRING",
+ "UUID": "STRING",
+ }
+
+ @apply_defaults
+ def __init__(self, *, trino_conn_id: str = "trino_default", **kwargs):
+ super().__init__(**kwargs)
+ self.trino_conn_id = trino_conn_id
+
+ def query(self):
+ """Queries trino and returns a cursor to the results."""
+ trino = TrinoHook(trino_conn_id=self.trino_conn_id)
+ conn = trino.get_conn()
+ cursor = conn.cursor()
+ self.log.info("Executing: %s", self.sql)
+ cursor.execute(self.sql)
+ return _TrinoToGCSTrinoCursorAdapter(cursor)
+
+ def field_to_bigquery(self, field) -> Dict[str, str]:
+ """Convert trino field type to BigQuery field type."""
+ clear_field_type = field[1].upper()
+ # remove type argument e.g. DECIMAL(2, 10) => DECIMAL
+ clear_field_type, _, _ = clear_field_type.partition("(")
+ new_field_type = self.type_map.get(clear_field_type, "STRING")
+
+ return {"name": field[0], "type": new_field_type}
+
+ def convert_type(self, value, schema_type):
+ """
+ Do nothing. Trino uses JSON on the transport layer, so types are simple.
+
+ :param value: Trino column value
+ :type value: Any
+ :param schema_type: BigQuery data type
+ :type schema_type: str
+ """
+ return value
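The adapter's peek mechanism is the key trick above: reading ``description`` triggers a
single ``fetchone()`` on the wrapped cursor and stashes the row, so the schema can be
inspected before any data is consumed. A minimal sketch of that behavior (``FakeCursor``
is a hypothetical stand-in for a real ``trino.dbapi.Cursor``, not part of this patch):

.. code-block:: python

    from airflow.providers.google.cloud.transfers.trino_to_gcs import (
        _TrinoToGCSTrinoCursorAdapter,
    )


    class FakeCursor:
        """Hypothetical stand-in for trino.dbapi.Cursor."""

        def __init__(self, rows):
            self._rows = list(rows)
            self.arraysize = 1

        @property
        def description(self):
            # In a real trino cursor this is only populated after the first fetch.
            return [("value", "integer", None, None, None, None, None)]

        def fetchone(self):
            return self._rows.pop(0) if self._rows else None


    cursor = _TrinoToGCSTrinoCursorAdapter(FakeCursor([(1,), (2,)]))
    print(cursor.description)  # available before an explicit fetch
    print(list(cursor))  # [(1,), (2,)] - the peeked row is not lost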
diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml
index 49170834fa698..6210f87869747 100644
--- a/airflow/providers/google/provider.yaml
+++ b/airflow/providers/google/provider.yaml
@@ -629,6 +629,10 @@ transfers:
target-integration-name: Google Cloud Storage (GCS)
how-to-guide: /docs/apache-airflow-providers-google/operators/transfer/presto_to_gcs.rst
python-module: airflow.providers.google.cloud.transfers.presto_to_gcs
+ - source-integration-name: Trino
+ target-integration-name: Google Cloud Storage (GCS)
+ how-to-guide: /docs/apache-airflow-providers-google/operators/transfer/trino_to_gcs.rst
+ python-module: airflow.providers.google.cloud.transfers.trino_to_gcs
- source-integration-name: SQL
target-integration-name: Google Cloud Storage (GCS)
python-module: airflow.providers.google.cloud.transfers.sql_to_gcs
diff --git a/airflow/providers/mysql/provider.yaml b/airflow/providers/mysql/provider.yaml
index a9b408fcf0174..3b40f50f67bd4 100644
--- a/airflow/providers/mysql/provider.yaml
+++ b/airflow/providers/mysql/provider.yaml
@@ -52,9 +52,12 @@ transfers:
- source-integration-name: Amazon Simple Storage Service (S3)
target-integration-name: MySQL
python-module: airflow.providers.mysql.transfers.s3_to_mysql
- - source-integration-name: Snowflake
+ - source-integration-name: Presto
target-integration-name: MySQL
python-module: airflow.providers.mysql.transfers.presto_to_mysql
+ - source-integration-name: Trino
+ target-integration-name: MySQL
+ python-module: airflow.providers.mysql.transfers.trino_to_mysql
hook-class-names:
- airflow.providers.mysql.hooks.mysql.MySqlHook
diff --git a/airflow/providers/mysql/transfers/trino_to_mysql.py b/airflow/providers/mysql/transfers/trino_to_mysql.py
new file mode 100644
index 0000000000000..b97550e116d8f
--- /dev/null
+++ b/airflow/providers/mysql/transfers/trino_to_mysql.py
@@ -0,0 +1,83 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import Dict, Optional
+
+from airflow.models import BaseOperator
+from airflow.providers.mysql.hooks.mysql import MySqlHook
+from airflow.providers.trino.hooks.trino import TrinoHook
+from airflow.utils.decorators import apply_defaults
+
+
+class TrinoToMySqlOperator(BaseOperator):
+ """
+ Moves data from Trino to MySQL. Note that for now the data is loaded
+ into memory before being pushed to MySQL, so this operator should
+ be used for small amounts of data.
+
+ :param sql: SQL query to execute against Trino. (templated)
+ :type sql: str
+ :param mysql_table: target MySQL table, use dot notation to target a
+ specific database. (templated)
+ :type mysql_table: str
+ :param mysql_conn_id: source mysql connection
+ :type mysql_conn_id: str
+ :param trino_conn_id: source trino connection
+ :type trino_conn_id: str
+ :param mysql_preoperator: SQL statement to run against MySQL prior to the
+ import, typically used to truncate or delete rows in place
+ of the incoming data, allowing the task to be idempotent (running
+ the task twice won't double-load data). (templated)
+ :type mysql_preoperator: str
+ """
+
+ template_fields = ('sql', 'mysql_table', 'mysql_preoperator')
+ template_ext = ('.sql',)
+ template_fields_renderers = {"mysql_preoperator": "sql"}
+ ui_color = '#a0e08c'
+
+ @apply_defaults
+ def __init__(
+ self,
+ *,
+ sql: str,
+ mysql_table: str,
+ trino_conn_id: str = 'trino_default',
+ mysql_conn_id: str = 'mysql_default',
+ mysql_preoperator: Optional[str] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.sql = sql
+ self.mysql_table = mysql_table
+ self.mysql_conn_id = mysql_conn_id
+ self.mysql_preoperator = mysql_preoperator
+ self.trino_conn_id = trino_conn_id
+
+ def execute(self, context: Dict) -> None:
+ trino = TrinoHook(trino_conn_id=self.trino_conn_id)
+ self.log.info("Extracting data from Trino: %s", self.sql)
+ results = trino.get_records(self.sql)
+
+ mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
+ if self.mysql_preoperator:
+ self.log.info("Running MySQL preoperator")
+ self.log.info(self.mysql_preoperator)
+ mysql.run(self.mysql_preoperator)
+
+ self.log.info("Inserting rows into MySQL")
+ mysql.insert_rows(table=self.mysql_table, rows=results)
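Because the whole result set is held in memory, this operator pairs naturally with a
``mysql_preoperator`` that truncates the target table so reruns stay idempotent. A
hypothetical DAG snippet (table, query, and connection ids are illustrative):

.. code-block:: python

    from airflow import models
    from airflow.providers.mysql.transfers.trino_to_mysql import TrinoToMySqlOperator
    from airflow.utils.dates import days_ago

    with models.DAG(
        dag_id="example_trino_to_mysql",
        schedule_interval=None,
        start_date=days_ago(1),
    ) as dag:
        load_nations = TrinoToMySqlOperator(
            task_id="load_nations",
            sql="SELECT name, nationkey FROM tpch.sf1.nation",
            mysql_table="analytics.nation",  # dot notation targets a specific database
            mysql_preoperator="TRUNCATE TABLE analytics.nation",  # keeps reruns idempotent
        )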
diff --git a/airflow/providers/trino/CHANGELOG.rst b/airflow/providers/trino/CHANGELOG.rst
new file mode 100644
index 0000000000000..cef7dda80708a
--- /dev/null
+++ b/airflow/providers/trino/CHANGELOG.rst
@@ -0,0 +1,25 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+
+Changelog
+---------
+
+1.0.0
+.....
+
+Initial version of the provider.
diff --git a/airflow/providers/trino/__init__.py b/airflow/providers/trino/__init__.py
new file mode 100644
index 0000000000000..217e5db960782
--- /dev/null
+++ b/airflow/providers/trino/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/trino/hooks/__init__.py b/airflow/providers/trino/hooks/__init__.py
new file mode 100644
index 0000000000000..217e5db960782
--- /dev/null
+++ b/airflow/providers/trino/hooks/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py
new file mode 100644
index 0000000000000..0914d04b32e4b
--- /dev/null
+++ b/airflow/providers/trino/hooks/trino.py
@@ -0,0 +1,191 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import os
+from typing import Any, Iterable, Optional
+
+import trino
+from trino.exceptions import DatabaseError
+from trino.transaction import IsolationLevel
+
+from airflow import AirflowException
+from airflow.configuration import conf
+from airflow.hooks.dbapi import DbApiHook
+from airflow.models import Connection
+
+
+class TrinoException(Exception):
+ """Trino exception"""
+
+
+def _boolify(value):
+ if isinstance(value, bool):
+ return value
+ if isinstance(value, str):
+ if value.lower() == 'false':
+ return False
+ elif value.lower() == 'true':
+ return True
+ return value
+
+
+class TrinoHook(DbApiHook):
+ """
+ Interact with Trino through trino package.
+
+ >>> th = TrinoHook()
+ >>> sql = "SELECT count(1) AS num FROM airflow.static_babynames"
+ >>> th.get_records(sql)
+ [[340698]]
+ """
+
+ conn_name_attr = 'trino_conn_id'
+ default_conn_name = 'trino_default'
+ conn_type = 'trino'
+ hook_name = 'Trino'
+
+ def get_conn(self) -> Connection:
+ """Returns a connection object"""
+ db = self.get_connection(
+ self.trino_conn_id # type: ignore[attr-defined] # pylint: disable=no-member
+ )
+ extra = db.extra_dejson
+ auth = None
+ if db.password and extra.get('auth') == 'kerberos':
+ raise AirflowException("Kerberos authorization doesn't support password.")
+ elif db.password:
+ auth = trino.auth.BasicAuthentication(db.login, db.password)
+ elif extra.get('auth') == 'kerberos':
+ auth = trino.auth.KerberosAuthentication(
+ config=extra.get('kerberos__config', os.environ.get('KRB5_CONFIG')),
+ service_name=extra.get('kerberos__service_name'),
+ mutual_authentication=_boolify(extra.get('kerberos__mutual_authentication', False)),
+ force_preemptive=_boolify(extra.get('kerberos__force_preemptive', False)),
+ hostname_override=extra.get('kerberos__hostname_override'),
+ sanitize_mutual_error_response=_boolify(
+ extra.get('kerberos__sanitize_mutual_error_response', True)
+ ),
+ principal=extra.get('kerberos__principal', conf.get('kerberos', 'principal')),
+ delegate=_boolify(extra.get('kerberos__delegate', False)),
+ ca_bundle=extra.get('kerberos__ca_bundle'),
+ )
+
+ trino_conn = trino.dbapi.connect(
+ host=db.host,
+ port=db.port,
+ user=db.login,
+ source=db.extra_dejson.get('source', 'airflow'),
+ http_scheme=db.extra_dejson.get('protocol', 'http'),
+ catalog=db.extra_dejson.get('catalog', 'hive'),
+ schema=db.schema,
+ auth=auth,
+ isolation_level=self.get_isolation_level(), # type: ignore[func-returns-value]
+ )
+ if extra.get('verify') is not None:
+ # Unfortunately, the verify parameter is not available via the public API.
+ # The PR adding it has been merged in the trino library, but has not been released yet.
+ # See: https://github.com/trinodb/trino-python-client/pull/31
+ trino_conn._http_session.verify = _boolify(extra['verify']) # pylint: disable=protected-access
+
+ return trino_conn
+
+ def get_isolation_level(self) -> Any:
+ """Returns an isolation level"""
+ db = self.get_connection(
+ self.trino_conn_id # type: ignore[attr-defined] # pylint: disable=no-member
+ )
+ isolation_level = db.extra_dejson.get('isolation_level', 'AUTOCOMMIT').upper()
+ return getattr(IsolationLevel, isolation_level, IsolationLevel.AUTOCOMMIT)
+
+ @staticmethod
+ def _strip_sql(sql: str) -> str:
+ return sql.strip().rstrip(';')
+
+ def get_records(self, hql, parameters: Optional[dict] = None):
+ """Get a set of records from Trino"""
+ try:
+ return super().get_records(self._strip_sql(hql), parameters)
+ except DatabaseError as e:
+ raise TrinoException(e)
+
+ def get_first(self, hql: str, parameters: Optional[dict] = None) -> Any:
+ """Returns only the first row, regardless of how many rows the query returns."""
+ try:
+ return super().get_first(self._strip_sql(hql), parameters)
+ except DatabaseError as e:
+ raise TrinoException(e)
+
+ def get_pandas_df(self, hql, parameters=None, **kwargs):
+ """Get a pandas dataframe from a sql query."""
+ import pandas
+
+ cursor = self.get_cursor()
+ try:
+ cursor.execute(self._strip_sql(hql), parameters)
+ data = cursor.fetchall()
+ except DatabaseError as e:
+ raise TrinoException(e)
+ column_descriptions = cursor.description
+ if data:
+ df = pandas.DataFrame(data, **kwargs)
+ df.columns = [c[0] for c in column_descriptions]
+ else:
+ df = pandas.DataFrame(**kwargs)
+ return df
+
+ def run(
+ self,
+ hql,
+ autocommit: bool = False,
+ parameters: Optional[dict] = None,
+ ) -> None:
+ """Execute the statement against Trino. Can be used to create views."""
+ return super().run(sql=self._strip_sql(hql), parameters=parameters)
+
+ def insert_rows(
+ self,
+ table: str,
+ rows: Iterable[tuple],
+ target_fields: Optional[Iterable[str]] = None,
+ commit_every: int = 0,
+ replace: bool = False,
+ **kwargs,
+ ) -> None:
+ """
+ A generic way to insert a set of tuples into a table.
+
+ :param table: Name of the target table
+ :type table: str
+ :param rows: The rows to insert into the table
+ :type rows: iterable of tuples
+ :param target_fields: The names of the columns to fill in the table
+ :type target_fields: iterable of strings
+ :param commit_every: The maximum number of rows to insert in one
+ transaction. Set to 0 to insert all rows in one transaction.
+ :type commit_every: int
+ :param replace: Whether to replace instead of insert
+ :type replace: bool
+ """
+ if self.get_isolation_level() == IsolationLevel.AUTOCOMMIT:
+ self.log.info(
+ 'Transactions are not enabled in the Trino connection. '
+ 'Please use the isolation_level property to enable them. '
+ 'Falling back to inserting all rows in one transaction.'
+ )
+ commit_every = 0
+
+ super().insert_rows(table, rows, target_fields, commit_every)
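Most of the hook's behavior is driven by the connection's ``extra`` JSON:
``protocol`` selects the HTTP scheme, ``catalog`` the Trino catalog, ``verify``
toggles TLS verification, and ``isolation_level`` feeds ``get_isolation_level()``
(transactions in ``insert_rows`` require a level other than ``AUTOCOMMIT``). A
sketch of a matching connection definition (host, port, and credentials are
illustrative):

.. code-block:: python

    import json

    from airflow.models import Connection

    conn = Connection(
        conn_id="trino_default",
        conn_type="trino",
        host="trino.example.com",
        port=8080,
        login="airflow",
        schema="default",
        extra=json.dumps(
            {
                "protocol": "https",  # passed to trino.dbapi.connect as http_scheme
                "catalog": "hive",  # this is also the default when omitted
                "verify": "false",  # applied to the underlying HTTP session
                "isolation_level": "REPEATABLE_READ",  # enables transactions
            }
        ),
    )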
diff --git a/airflow/providers/trino/provider.yaml b/airflow/providers/trino/provider.yaml
new file mode 100644
index 0000000000000..a59aaae6abbe1
--- /dev/null
+++ b/airflow/providers/trino/provider.yaml
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+---
+package-name: apache-airflow-providers-trino
+name: Trino
+description: |
+ `Trino <https://trino.io/>`__
+
+versions:
+ - 1.0.0
+
+integrations:
+ - integration-name: Trino
+ external-doc-url: https://trino.io/docs/
+ logo: /integration-logos/trino/trino-og.png
+ tags: [software]
+
+hooks:
+ - integration-name: Trino
+ python-modules:
+ - airflow.providers.trino.hooks.trino
+
+hook-class-names:
+ - airflow.providers.trino.hooks.trino.TrinoHook
diff --git a/airflow/sensors/sql.py b/airflow/sensors/sql.py
index 923af6c75d062..efe88fff06648 100644
--- a/airflow/sensors/sql.py
+++ b/airflow/sensors/sql.py
@@ -84,6 +84,7 @@ def _get_hook(self):
'presto',
'snowflake',
'sqlite',
+ 'trino',
'vertica',
}
if conn.conn_type not in allowed_conn_type:
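With ``trino`` added to ``allowed_conn_type``, an ``SqlSensor`` can now poke a Trino
connection directly. A hypothetical usage (connection id and query are illustrative):

.. code-block:: python

    from airflow.sensors.sql import SqlSensor

    wait_for_rows = SqlSensor(
        task_id="wait_for_rows",
        conn_id="trino_default",
        sql="SELECT count(*) FROM memory.default.events",
        poke_interval=60,  # seconds between pokes
    )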
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 46ace4aa7b523..0a30901fce230 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -516,6 +516,16 @@ def create_default_connections(session=None):
),
session,
)
+ merge_conn(
+ Connection(
+ conn_id="trino_default",
+ conn_type="trino",
+ host="localhost",
+ schema="hive",
+ port=3400,
+ ),
+ session,
+ )
merge_conn(
Connection(
conn_id="vertica_default",
diff --git a/breeze b/breeze
index c85a5acf9ca2e..302cebbdbb285 100755
--- a/breeze
+++ b/breeze
@@ -819,7 +819,7 @@ function breeze::parse_arguments() {
else
INTEGRATIONS+=("${INTEGRATION}")
fi
- if [[ " ${INTEGRATIONS[*]} " =~ " presto " ]]; then
+ if [[ " ${INTEGRATIONS[*]} " =~ " trino " ]]; then
INTEGRATIONS+=("kerberos");
fi
echo
diff --git a/breeze-complete b/breeze-complete
index daa03bce79b81..2004a1a9ae185 100644
--- a/breeze-complete
+++ b/breeze-complete
@@ -25,7 +25,7 @@
_breeze_allowed_python_major_minor_versions="2.7 3.5 3.6 3.7 3.8"
_breeze_allowed_backends="sqlite mysql postgres"
-_breeze_allowed_integrations="cassandra kerberos mongo openldap pinot presto rabbitmq redis statsd all"
+_breeze_allowed_integrations="cassandra kerberos mongo openldap pinot rabbitmq redis statsd trino all"
_breeze_allowed_generate_constraints_modes="source-providers pypi-providers no-providers"
# registrys is good here even if it is not correct english. We are adding s automatically to all variables
_breeze_allowed_github_registrys="docker.pkg.github.com ghcr.io"
diff --git a/docs/apache-airflow-providers-google/operators/transfer/trino_to_gcs.rst b/docs/apache-airflow-providers-google/operators/transfer/trino_to_gcs.rst
new file mode 100644
index 0000000000000..29dc5405e7484
--- /dev/null
+++ b/docs/apache-airflow-providers-google/operators/transfer/trino_to_gcs.rst
@@ -0,0 +1,142 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+
+Trino to Google Cloud Storage Transfer Operator
+===============================================
+
+`Trino <https://trino.io/>`__ is an open source, fast, distributed SQL query engine for running interactive
+analytic queries against data sources of all sizes ranging from gigabytes to petabytes. Trino allows
+querying data where it lives, including Hive, Cassandra, relational databases or even proprietary data stores.
+A single Trino query can combine data from multiple sources, allowing for analytics across your entire
+organization.
+
+`Google Cloud Storage <https://cloud.google.com/storage/>`__ allows world-wide storage and retrieval of
+any amount of data at any time. You can use it to store backup and
+`archive data <https://cloud.google.com/storage/archival>`__ as well
+as a `data source for BigQuery <https://cloud.google.com/bigquery/external-data-cloud-storage>`__.
+
+
+Data transfer
+-------------
+
+Transfers between Trino and Google Cloud Storage are performed with the
+:class:`~airflow.providers.google.cloud.transfers.trino_to_gcs.TrinoToGCSOperator` operator.
+
+This operator has 3 required parameters:
+
+* ``sql`` - The SQL to execute.
+* ``bucket`` - The bucket to upload to.
+* ``filename`` - The filename to use as the object name when uploading to Google Cloud Storage.
+ A ``{}`` should be specified in the filename to allow the operator to inject file
+ numbers in cases where the file is split due to size.
+
+All parameters are described in the reference documentation - :class:`~airflow.providers.google.cloud.transfers.trino_to_gcs.TrinoToGCSOperator`.
+
+An example operator call might look like this:
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_trino_to_gcs.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_trino_to_gcs_basic]
+ :end-before: [END howto_operator_trino_to_gcs_basic]
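+
+The ``{}`` placeholder in ``filename`` is filled with the chunk number as the result
+is written, for example (``customer`` is an arbitrary name):
+
+.. code-block:: python
+
+    filename = "customer.{}.json"
+    filename.format(0)  # "customer.0.json" - first chunk
+    filename.format(1)  # "customer.1.json" - second chunk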
+
+Choice of data format
+^^^^^^^^^^^^^^^^^^^^^
+
+The operator supports two output formats:
+
+* ``json`` - JSON Lines (default)
+* ``csv``
+
+You can select the format with the ``export_format`` parameter.
+
+If you want a CSV file to be created, your operator call might look like this:
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_trino_to_gcs.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_trino_to_gcs_csv]
+ :end-before: [END howto_operator_trino_to_gcs_csv]
+
+Generating BigQuery schema
+^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+If you set the ``schema_filename`` parameter, a ``.json`` file containing the BigQuery schema fields for the table
+will be dumped from the database and uploaded to the bucket.
+
+If you want to create a schema file, then an example operator call might look like this:
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_trino_to_gcs.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_trino_to_gcs_multiple_types]
+ :end-before: [END howto_operator_trino_to_gcs_multiple_types]
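+
+With the operator's ``type_map``, a Trino column declared as ``DECIMAL(10, 2)`` has its
+type arguments stripped and maps to BigQuery ``NUMERIC``; unknown types fall back to
+``STRING``. A single entry in the dumped schema file then looks roughly like this:
+
+.. code-block:: python
+
+    {"name": "price", "type": "NUMERIC"}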
+
+For more information about the BigQuery schema, please look at
+`Specifying schema <https://cloud.google.com/bigquery/docs/schemas>`__ in the BigQuery documentation.
+
+Division of the result into multiple files
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+This operator supports the ability to split a large result into multiple files. The ``approx_max_file_size_bytes``
+parameter allows developers to specify the approximate file size of the splits. By default, each file contains no more
+than 1,900,000,000 bytes (1900 MB).
+
+Check `Quotas & limits in Google Cloud Storage <https://cloud.google.com/storage/quotas>`__ to see the
+maximum allowed file size for a single object.
+
+If you want to create 10 MB files, your code might look like this:
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_trino_to_gcs.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_trino_to_gcs_many_chunks]
+ :end-before: [END howto_operator_trino_to_gcs_many_chunks]
+
+Querying data using BigQuery
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The data available in Google Cloud Storage can be used by BigQuery. You can load data into BigQuery or
+reference GCS data directly in queries. For information about loading data into BigQuery, please look at
+`Introduction to loading data from Cloud Storage <https://cloud.google.com/bigquery/docs/loading-data-cloud-storage>`__
+in the BigQuery documentation. For information about querying GCS data, please look at
+`Querying Cloud Storage data <https://cloud.google.com/bigquery/external-data-cloud-storage>`__ in
+the BigQuery documentation.
+
+Airflow also has numerous operators that allow you to make use of BigQuery.
+For example, if you want to create an external table that allows queries to
+read data directly from GCS, you can use :class:`~airflow.providers.google.cloud.operators.bigquery.BigQueryCreateExternalTableOperator`.
+Using this operator looks like this:
+
+.. exampleinclude:: /../../airflow/providers/google/cloud/example_dags/example_trino_to_gcs.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_create_external_table_multiple_types]
+ :end-before: [END howto_operator_create_external_table_multiple_types]
+
+For more information about the Airflow and BigQuery integration, please look at
+the Python API Reference - :mod:`airflow.providers.google.cloud.operators.bigquery`.
+
+Reference
+^^^^^^^^^
+
+For further information, look at:
+
+* `Trino Documentation <https://trino.io/docs/>`__
+
+* `Google Cloud Storage Documentation <https://cloud.google.com/storage/docs/>`__
diff --git a/docs/apache-airflow-providers-trino/commits.rst b/docs/apache-airflow-providers-trino/commits.rst
new file mode 100644
index 0000000000000..5f0341d81e984
--- /dev/null
+++ b/docs/apache-airflow-providers-trino/commits.rst
@@ -0,0 +1,26 @@
+
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+Package apache-airflow-providers-trino
+------------------------------------------------------
+
+`Trino <https://trino.io/>`__
+
+
+This is a detailed commit list of changes for the ``trino`` provider package.
+For the high-level changelog, see :doc:`package information including changelog <index>`.
diff --git a/docs/apache-airflow-providers-trino/index.rst b/docs/apache-airflow-providers-trino/index.rst
new file mode 100644
index 0000000000000..e74c7d62e7d23
--- /dev/null
+++ b/docs/apache-airflow-providers-trino/index.rst
@@ -0,0 +1,43 @@
+
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+``apache-airflow-providers-trino``
+===================================
+
+Content
+-------
+
+.. toctree::
+ :maxdepth: 1
+ :caption: References
+
+ Python API <_api/airflow/providers/trino/index>
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Resources
+
+ PyPI Repository <https://pypi.org/project/apache-airflow-providers-trino/>
+
+.. THE REMAINDER OF THE FILE IS AUTOMATICALLY GENERATED. IT WILL BE OVERWRITTEN AT RELEASE TIME!
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Commits
+
+ Detailed list of commits <commits>
diff --git a/docs/apache-airflow/extra-packages-ref.rst b/docs/apache-airflow/extra-packages-ref.rst
index 174edec59347d..b902868f5c164 100644
--- a/docs/apache-airflow/extra-packages-ref.rst
+++ b/docs/apache-airflow/extra-packages-ref.rst
@@ -54,7 +54,7 @@ python dependencies for the provided package.
+---------------------+-----------------------------------------------------+----------------------------------------------------------------------------+
| google_auth | ``pip install 'apache-airflow[google_auth]'`` | Google auth backend |
+---------------------+-----------------------------------------------------+----------------------------------------------------------------------------+
-| kerberos | ``pip install 'apache-airflow[kerberos]'`` | Kerberos integration for Kerberized services (Hadoop, Presto) |
+| kerberos | ``pip install 'apache-airflow[kerberos]'`` | Kerberos integration for Kerberized services (Hadoop, Presto, Trino) |
+---------------------+-----------------------------------------------------+----------------------------------------------------------------------------+
| ldap | ``pip install 'apache-airflow[ldap]'`` | LDAP authentication for users |
+---------------------+-----------------------------------------------------+----------------------------------------------------------------------------+
@@ -235,6 +235,8 @@ Those are extras that add dependencies needed for integration with other softwar
+---------------------+-----------------------------------------------------+-------------------------------------------+
| singularity | ``pip install 'apache-airflow[singularity]'`` | Singularity container operator |
+---------------------+-----------------------------------------------------+-------------------------------------------+
+| trino | ``pip install 'apache-airflow[trino]'`` | All Trino related operators & hooks |
++---------------------+-----------------------------------------------------+-------------------------------------------+
Other extras
diff --git a/docs/exts/docs_build/errors.py b/docs/exts/docs_build/errors.py
index 1a2ae0698073a..3fe9f36d810b3 100644
--- a/docs/exts/docs_build/errors.py
+++ b/docs/exts/docs_build/errors.py
@@ -69,7 +69,7 @@ def display_errors_summary(build_errors: Dict[str, List[DocBuildError]]) -> None
console.print("-" * 30, f"[red]Error {warning_no:3}[/]", "-" * 20)
console.print(error.message)
console.print()
- if error.file_path and error.file_path != "<unknown>" and error.line_no:
+ if error.file_path and not error.file_path.endswith("<unknown>") and error.line_no:
console.print(
f"File path: {os.path.relpath(error.file_path, start=DOCS_DIR)} ({error.line_no})"
)
diff --git a/docs/exts/docs_build/spelling_checks.py b/docs/exts/docs_build/spelling_checks.py
index f0b272267b535..24ce3f170584f 100644
--- a/docs/exts/docs_build/spelling_checks.py
+++ b/docs/exts/docs_build/spelling_checks.py
@@ -180,6 +180,6 @@ def _display_error(error: SpellingError):
console.print(f"Suggested Spelling: '{error.suggestion}'")
if error.context_line:
console.print(f"Line with Error: '{error.context_line}'")
- if error.line_no:
+ if error.file_path and not error.file_path.endswith("<unknown>") and error.line_no:
console.print(f"Line Number: {error.line_no}")
console.print(prepare_code_snippet(error.file_path, error.line_no))
diff --git a/docs/integration-logos/trino/trino-og.png b/docs/integration-logos/trino/trino-og.png
new file mode 100644
index 0000000000000..55bedf93dd346
Binary files /dev/null and b/docs/integration-logos/trino/trino-og.png differ
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index f203300c76a50..70c6fb32f0a1d 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1352,6 +1352,7 @@ tooltips
traceback
tracebacks
travis
+trino
trojan
tsv
ttl
diff --git a/scripts/ci/docker-compose/integration-kerberos.yml b/scripts/ci/docker-compose/integration-kerberos.yml
index d157bd684c82d..95fc8c9c79e2f 100644
--- a/scripts/ci/docker-compose/integration-kerberos.yml
+++ b/scripts/ci/docker-compose/integration-kerberos.yml
@@ -36,7 +36,7 @@ services:
/opt/kerberos-utils/create_client.sh bob bob /root/kerberos-keytabs/airflow.keytab;
/opt/kerberos-utils/create_service.sh krb5-machine-example-com airflow
/root/kerberos-keytabs/airflow.keytab;
- /opt/kerberos-utils/create_service.sh presto HTTP /root/kerberos-keytabs/presto.keytab;
+ /opt/kerberos-utils/create_service.sh trino HTTP /root/kerberos-keytabs/trino.keytab;
healthcheck:
test: |-
python -c "
diff --git a/scripts/ci/docker-compose/integration-redis.yml b/scripts/ci/docker-compose/integration-redis.yml
index ab353d267ebe8..3cdf68caf18b9 100644
--- a/scripts/ci/docker-compose/integration-redis.yml
+++ b/scripts/ci/docker-compose/integration-redis.yml
@@ -21,7 +21,7 @@ services:
image: redis:5.0.1
volumes:
- /dev/urandom:/dev/random # Required to get non-blocking entropy source
- - redis-db-volume:/data/presto
+ - redis-db-volume:/data/redis
ports:
- "${REDIS_HOST_PORT}:6379"
healthcheck:
diff --git a/scripts/ci/docker-compose/integration-presto.yml b/scripts/ci/docker-compose/integration-trino.yml
similarity index 81%
rename from scripts/ci/docker-compose/integration-presto.yml
rename to scripts/ci/docker-compose/integration-trino.yml
index 7fce2069b31c9..3f420fb61ca17 100644
--- a/scripts/ci/docker-compose/integration-presto.yml
+++ b/scripts/ci/docker-compose/integration-trino.yml
@@ -17,10 +17,10 @@
---
version: "2.2"
services:
- presto:
- image: apache/airflow:presto-2020.10.08
- container_name: presto
- hostname: presto
+ trino:
+ image: apache/airflow:trino-2021.04.04
+ container_name: trino
+ hostname: trino
domainname: example.com
networks:
@@ -40,19 +40,19 @@ services:
volumes:
- /dev/urandom:/dev/random # Required to get non-blocking entropy source
- ../dockerfiles/krb5-kdc-server/krb5.conf:/etc/krb5.conf:ro
- - presto-db-volume:/data/presto
- - kerberos-keytabs:/home/presto/kerberos-keytabs
+ - trino-db-volume:/data/trino
+ - kerberos-keytabs:/home/trino/kerberos-keytabs
environment:
- KRB5_CONFIG=/etc/krb5.conf
- KRB5_TRACE=/dev/stderr
- - KRB5_KTNAME=/home/presto/kerberos-keytabs/presto.keytab
+ - KRB5_KTNAME=/home/trino/kerberos-keytabs/trino.keytab
airflow:
environment:
- - INTEGRATION_PRESTO=true
+ - INTEGRATION_TRINO=true
depends_on:
- presto:
+ trino:
condition: service_healthy
volumes:
- presto-db-volume:
+ trino-db-volume:
diff --git a/scripts/ci/dockerfiles/krb5-kdc-server/utils/create_service.sh b/scripts/ci/dockerfiles/krb5-kdc-server/utils/create_service.sh
index 30161a3f6c5c7..c92aeab70f629 100755
--- a/scripts/ci/dockerfiles/krb5-kdc-server/utils/create_service.sh
+++ b/scripts/ci/dockerfiles/krb5-kdc-server/utils/create_service.sh
@@ -29,7 +29,7 @@ Usage: ${CMDNAME}
Creates an account for the service.
The service name is combined with the domain to create a principal name. If your service is named
-\"presto\" a principal \"presto.example.com\" will be created.
+\"trino\" a principal \"trino.example.com\" will be created.
The protocol can have any value, but it must be identical in the server and client configuration.
For example: HTTP.
diff --git a/scripts/ci/dockerfiles/presto/Dockerfile b/scripts/ci/dockerfiles/trino/Dockerfile
similarity index 78%
rename from scripts/ci/dockerfiles/presto/Dockerfile
rename to scripts/ci/dockerfiles/trino/Dockerfile
index 80ccbfd344527..080491f6a7f4e 100644
--- a/scripts/ci/dockerfiles/presto/Dockerfile
+++ b/scripts/ci/dockerfiles/trino/Dockerfile
@@ -14,8 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-ARG PRESTO_VERSION="330"
-FROM prestosql/presto:${PRESTO_VERSION}
+ARG TRINO_VERSION="354"
+FROM trinodb/trino:${TRINO_VERSION}
# Obtain root privileges
USER 0
@@ -23,16 +23,16 @@ USER 0
# Setup entrypoint
COPY entrypoint.sh /entrypoint.sh
ENTRYPOINT ["/entrypoint.sh"]
-CMD ["/usr/lib/presto/bin/run-presto"]
+CMD ["/usr/lib/trino/bin/run-trino"]
# Expose HTTPS
EXPOSE 7778
-LABEL org.apache.airflow.component="presto"
-LABEL org.apache.airflow.presto.core.version="${PRESTO_VERSION}"
-LABEL org.apache.airflow.airflow_bats.version="${AIRFLOW_PRESTO_VERSION}"
+LABEL org.apache.airflow.component="trino"
+LABEL org.apache.airflow.trino.core.version="${TRINO_VERSION}"
+LABEL org.apache.airflow.airflow_trino.version="${AIRFLOW_TRINO_VERSION}"
LABEL org.apache.airflow.commit_sha="${COMMIT_SHA}"
LABEL maintainer="Apache Airflow Community <dev@airflow.apache.org>"
# Restore user
-USER presto:presto
+USER trino:trino
diff --git a/scripts/ci/dockerfiles/presto/build_and_push.sh b/scripts/ci/dockerfiles/trino/build_and_push.sh
similarity index 79%
rename from scripts/ci/dockerfiles/presto/build_and_push.sh
rename to scripts/ci/dockerfiles/trino/build_and_push.sh
index d3cac47775b69..ea8a59d0742b0 100755
--- a/scripts/ci/dockerfiles/presto/build_and_push.sh
+++ b/scripts/ci/dockerfiles/trino/build_and_push.sh
@@ -21,24 +21,24 @@ DOCKERHUB_REPO=${DOCKERHUB_REPO:="airflow"}
readonly DOCKERHUB_USER
readonly DOCKERHUB_REPO
-PRESTO_VERSION="330"
-readonly PRESTO_VERSION
+TRINO_VERSION="354"
+readonly TRINO_VERSION
-AIRFLOW_PRESTO_VERSION="2020.10.08"
-readonly AIRFLOW_PRESTO_VERSION
+AIRFLOW_TRINO_VERSION="2021.04.04"
+readonly AIRFLOW_TRINO_VERSION
COMMIT_SHA=$(git rev-parse HEAD)
readonly COMMIT_SHA
cd "$( dirname "${BASH_SOURCE[0]}" )" || exit 1
-TAG="${DOCKERHUB_USER}/${DOCKERHUB_REPO}:presto-${AIRFLOW_PRESTO_VERSION}"
+TAG="${DOCKERHUB_USER}/${DOCKERHUB_REPO}:trino-${AIRFLOW_TRINO_VERSION}"
readonly TAG
docker build . \
--pull \
- --build-arg "PRESTO_VERSION=${PRESTO_VERSION}" \
- --build-arg "AIRFLOW_PRESTO_VERSION=${AIRFLOW_PRESTO_VERSION}" \
+ --build-arg "TRINO_VERSION=${TRINO_VERSION}" \
+ --build-arg "AIRFLOW_TRINO_VERSION=${AIRFLOW_TRINO_VERSION}" \
--build-arg "COMMIT_SHA=${COMMIT_SHA}" \
--tag "${TAG}"
diff --git a/scripts/ci/dockerfiles/presto/entrypoint.sh b/scripts/ci/dockerfiles/trino/entrypoint.sh
similarity index 73%
rename from scripts/ci/dockerfiles/presto/entrypoint.sh
rename to scripts/ci/dockerfiles/trino/entrypoint.sh
index 9c8d1130beeb2..314cc5a8ee166 100755
--- a/scripts/ci/dockerfiles/presto/entrypoint.sh
+++ b/scripts/ci/dockerfiles/trino/entrypoint.sh
@@ -32,7 +32,7 @@ function check_service {
RES=$?
set -e
if [[ ${RES} == 0 ]]; then
- echo "${COLOR_GREEN}OK. ${COLOR_RESET}"
+ echo "OK."
break
else
echo -n "."
@@ -58,27 +58,29 @@ function log() {
echo -e "\u001b[32m[$(date +'%Y-%m-%dT%H:%M:%S%z')]: $*\u001b[0m"
}
-if [ -f /tmp/presto-initiaalized ]; then
+if [ -f /tmp/trino-initialized ]; then
exec /bin/sh -c "$@"
fi
-PRESTO_CONFIG_FILE="/usr/lib/presto/default/etc/config.properties"
-JVM_CONFIG_FILE="/usr/lib/presto/default/etc/jvm.config"
+TRINO_CONFIG_FILE="/etc/trino/config.properties"
+JVM_CONFIG_FILE="/etc/trino/jvm.config"
log "Generate self-signed SSL certificate"
JKS_KEYSTORE_FILE=/tmp/ssl_keystore.jks
-JKS_KEYSTORE_PASS=presto
+JKS_KEYSTORE_PASS=trinodb
+keytool -delete --alias "trino-ssl" -keystore "${JKS_KEYSTORE_FILE}" -storepass "${JKS_KEYSTORE_PASS}" || true
+
keytool \
-genkeypair \
- -alias "presto-ssl" \
+ -alias "trino-ssl" \
-keyalg RSA \
-keystore "${JKS_KEYSTORE_FILE}" \
-validity 10000 \
-dname "cn=Unknown, ou=Unknown, o=Unknown, c=Unknown"\
-storepass "${JKS_KEYSTORE_PASS}"
-log "Set up SSL in ${PRESTO_CONFIG_FILE}"
-cat << EOF >> "${PRESTO_CONFIG_FILE}"
+log "Set up SSL in ${TRINO_CONFIG_FILE}"
+cat << EOF >> "${TRINO_CONFIG_FILE}"
http-server.https.enabled=true
http-server.https.port=7778
http-server.https.keystore.path=${JKS_KEYSTORE_FILE}
@@ -86,9 +88,18 @@ http-server.https.keystore.key=${JKS_KEYSTORE_PASS}
node.internal-address-source=FQDN
EOF
+log "Set up memory limits in ${TRINO_CONFIG_FILE}"
+cat << EOF >> "${TRINO_CONFIG_FILE}"
+memory.heap-headroom-per-node=128MB
+query.max-memory-per-node=512MB
+query.max-total-memory-per-node=512MB
+EOF
+
+sed -i "s/Xmx.*$/Xmx640M/" "${JVM_CONFIG_FILE}"
+
if [[ -n "${KRB5_CONFIG=}" ]]; then
- log "Set up Kerberos in ${PRESTO_CONFIG_FILE}"
- cat << EOF >> "${PRESTO_CONFIG_FILE}"
+ log "Set up Kerberos in ${TRINO_CONFIG_FILE}"
+ cat << EOF >> "${TRINO_CONFIG_FILE}"
http-server.https.enabled=true
http-server.https.port=7778
http-server.https.keystore.path=${JKS_KEYSTORE_FILE}
@@ -103,16 +114,18 @@ EOF
EOF
fi
-log "Waiting for keytab:${KRB5_KTNAME}"
-check_service "Keytab" "test -f ${KRB5_KTNAME}" 30
+if [[ -n "${KRB5_CONFIG=}" ]]; then
+ log "Waiting for keytab:${KRB5_KTNAME}"
+ check_service "Keytab" "test -f ${KRB5_KTNAME}" 30
+fi
-touch /tmp/presto-initiaalized
+touch /tmp/trino-initialized
echo "Config: ${JVM_CONFIG_FILE}"
cat "${JVM_CONFIG_FILE}"
-echo "Config: ${PRESTO_CONFIG_FILE}"
-cat "${PRESTO_CONFIG_FILE}"
+echo "Config: ${TRINO_CONFIG_FILE}"
+cat "${TRINO_CONFIG_FILE}"
log "Executing cmd: ${*}"
exec /bin/sh -c "${@}"
diff --git a/scripts/ci/libraries/_initialization.sh b/scripts/ci/libraries/_initialization.sh
index c6e2b41bbcff7..bbdf116269137 100644
--- a/scripts/ci/libraries/_initialization.sh
+++ b/scripts/ci/libraries/_initialization.sh
@@ -179,7 +179,7 @@ function initialization::initialize_dockerhub_variables() {
# Determine available integrations
function initialization::initialize_available_integrations() {
- export AVAILABLE_INTEGRATIONS="cassandra kerberos mongo openldap pinot presto rabbitmq redis statsd"
+ export AVAILABLE_INTEGRATIONS="cassandra kerberos mongo openldap pinot rabbitmq redis statsd trino"
}
# Needs to be declared outside of function for MacOS
diff --git a/scripts/in_container/check_environment.sh b/scripts/in_container/check_environment.sh
index 801477eede9c7..22c6fe58d2092 100755
--- a/scripts/in_container/check_environment.sh
+++ b/scripts/in_container/check_environment.sh
@@ -160,17 +160,17 @@ check_integration "MongoDB" "mongo" "run_nc mongo 27017" 50
check_integration "Redis" "redis" "run_nc redis 6379" 50
check_integration "Cassandra" "cassandra" "run_nc cassandra 9042" 50
check_integration "OpenLDAP" "openldap" "run_nc openldap 389" 50
-check_integration "Presto (HTTP)" "presto" "run_nc presto 8080" 50
-check_integration "Presto (HTTPS)" "presto" "run_nc presto 7778" 50
-check_integration "Presto (API)" "presto" \
- "curl --max-time 1 http://presto:8080/v1/info/ | grep '\"starting\":false'" 50
+check_integration "Trino (HTTP)" "trino" "run_nc trino 8080" 50
+check_integration "Trino (HTTPS)" "trino" "run_nc trino 7778" 50
+check_integration "Trino (API)" "trino" \
+ "curl --max-time 1 http://trino:8080/v1/info/ | grep '\"starting\":false'" 50
check_integration "Pinot (HTTP)" "pinot" "run_nc pinot 9000" 50
CMD="curl --max-time 1 -X GET 'http://pinot:9000/health' -H 'accept: text/plain' | grep OK"
-check_integration "Presto (Controller API)" "pinot" "${CMD}" 50
+check_integration "Pinot (Controller API)" "pinot" "${CMD}" 50
CMD="curl --max-time 1 -X GET 'http://pinot:9000/pinot-controller/admin' -H 'accept: text/plain' | grep GOOD"
-check_integration "Presto (Controller API)" "pinot" "${CMD}" 50
+check_integration "Pinot (Controller API)" "pinot" "${CMD}" 50
CMD="curl --max-time 1 -X GET 'http://pinot:8000/health' -H 'accept: text/plain' | grep OK"
-check_integration "Presto (Broker API)" "pinot" "${CMD}" 50
+check_integration "Pinot (Broker API)" "pinot" "${CMD}" 50
check_integration "RabbitMQ" "rabbitmq" "run_nc rabbitmq 5672" 50
echo "-----------------------------------------------------------------------------------------------"
diff --git a/scripts/in_container/run_install_and_test_provider_packages.sh b/scripts/in_container/run_install_and_test_provider_packages.sh
index 4eeb83c6efbf6..f6d31b6cbb4de 100755
--- a/scripts/in_container/run_install_and_test_provider_packages.sh
+++ b/scripts/in_container/run_install_and_test_provider_packages.sh
@@ -95,7 +95,7 @@ function discover_all_provider_packages() {
# Columns is to force it wider, so it doesn't wrap at 80 characters
COLUMNS=180 airflow providers list
- local expected_number_of_providers=65
+ local expected_number_of_providers=66
local actual_number_of_providers
actual_providers=$(airflow providers list --output yaml | grep package_name)
actual_number_of_providers=$(wc -l <<<"$actual_providers")
@@ -118,7 +118,7 @@ function discover_all_hooks() {
group_start "Listing available hooks via 'airflow providers hooks'"
COLUMNS=180 airflow providers hooks
- local expected_number_of_hooks=62
+ local expected_number_of_hooks=63
local actual_number_of_hooks
actual_number_of_hooks=$(airflow providers hooks --output table | grep -c "| apache" | xargs)
if [[ ${actual_number_of_hooks} != "${expected_number_of_hooks}" ]]; then
diff --git a/setup.py b/setup.py
index e1f1ebb826160..9441b6bdd96ff 100644
--- a/setup.py
+++ b/setup.py
@@ -454,6 +454,7 @@ def get_sphinx_theme_version() -> str:
telegram = [
'python-telegram-bot==13.0',
]
+trino = ['trino']
vertica = [
'vertica-python>=0.5.1',
]
@@ -584,6 +585,7 @@ def get_sphinx_theme_version() -> str:
'ssh': ssh,
'tableau': tableau,
'telegram': telegram,
+ 'trino': trino,
'vertica': vertica,
'yandex': yandex,
'zendesk': zendesk,
@@ -718,6 +720,7 @@ def add_extras_for_all_deprecated_aliases() -> None:
'neo4j',
'postgres',
'presto',
+ 'trino',
'vertica',
]
@@ -933,7 +936,9 @@ def add_all_provider_packages() -> None:
add_provider_packages_to_extra_requirements("devel_ci", ALL_PROVIDERS)
add_provider_packages_to_extra_requirements("devel_all", ALL_PROVIDERS)
add_provider_packages_to_extra_requirements("all_dbs", ALL_DB_PROVIDERS)
- add_provider_packages_to_extra_requirements("devel_hadoop", ["apache.hdfs", "apache.hive", "presto"])
+ add_provider_packages_to_extra_requirements(
+ "devel_hadoop", ["apache.hdfs", "apache.hive", "presto", "trino"]
+ )
class Develop(develop_orig):
diff --git a/tests/cli/commands/test_connection_command.py b/tests/cli/commands/test_connection_command.py
index cf27941776889..136811dbe3e44 100644
--- a/tests/cli/commands/test_connection_command.py
+++ b/tests/cli/commands/test_connection_command.py
@@ -101,6 +101,10 @@ class TestCliListConnections(unittest.TestCase):
'sqlite_default',
'sqlite',
),
+ (
+ 'trino_default',
+ 'trino',
+ ),
(
'vertica_default',
'vertica',
diff --git a/tests/conftest.py b/tests/conftest.py
index d828642660174..bbf617d65ed1a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -128,7 +128,7 @@ def pytest_addoption(parser):
action="append",
metavar="INTEGRATIONS",
help="only run tests matching integration specified: "
- "[cassandra,kerberos,mongo,openldap,presto,rabbitmq,redis]. ",
+ "[cassandra,kerberos,mongo,openldap,rabbitmq,redis,statsd,trino]. ",
)
group.addoption(
"--backend",
diff --git a/tests/core/test_providers_manager.py b/tests/core/test_providers_manager.py
index 00880fd2e4c21..f38c8f663cdb3 100644
--- a/tests/core/test_providers_manager.py
+++ b/tests/core/test_providers_manager.py
@@ -83,6 +83,7 @@
'apache-airflow-providers-ssh',
'apache-airflow-providers-tableau',
'apache-airflow-providers-telegram',
+ 'apache-airflow-providers-trino',
'apache-airflow-providers-vertica',
'apache-airflow-providers-yandex',
'apache-airflow-providers-zendesk',
@@ -147,6 +148,7 @@
'sqoop',
'ssh',
'tableau',
+ 'trino',
'vault',
'vertica',
'wasb',
diff --git a/tests/providers/google/cloud/transfers/test_trino_to_gcs.py b/tests/providers/google/cloud/transfers/test_trino_to_gcs.py
new file mode 100644
index 0000000000000..7cb6539a3d846
--- /dev/null
+++ b/tests/providers/google/cloud/transfers/test_trino_to_gcs.py
@@ -0,0 +1,331 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import unittest
+from unittest.mock import patch
+
+import pytest
+
+from airflow.providers.google.cloud.transfers.trino_to_gcs import TrinoToGCSOperator
+
+TASK_ID = "test-trino-to-gcs"
+TRINO_CONN_ID = "my-trino-conn"
+GCP_CONN_ID = "my-gcp-conn"
+IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"]
+SQL = "SELECT * FROM memory.default.test_multiple_types"
+BUCKET = "gs://test"
+FILENAME = "test_{}.ndjson"
+
+NDJSON_LINES = [
+ b'{"some_num": 42, "some_str": "mock_row_content_1"}\n',
+ b'{"some_num": 43, "some_str": "mock_row_content_2"}\n',
+ b'{"some_num": 44, "some_str": "mock_row_content_3"}\n',
+]
+CSV_LINES = [
+ b"some_num,some_str\r\n",
+ b"42,mock_row_content_1\r\n",
+ b"43,mock_row_content_2\r\n",
+ b"44,mock_row_content_3\r\n",
+]
+SCHEMA_FILENAME = "schema_test.json"
+SCHEMA_JSON = b'[{"name": "some_num", "type": "INT64"}, {"name": "some_str", "type": "STRING"}]'
+
+
+@pytest.mark.integration("trino")
+class TestTrinoToGCSOperator(unittest.TestCase):
+ def test_init(self):
+ """Test TrinoToGCSOperator instance is properly initialized."""
+ op = TrinoToGCSOperator(
+ task_id=TASK_ID,
+ sql=SQL,
+ bucket=BUCKET,
+ filename=FILENAME,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+ assert op.task_id == TASK_ID
+ assert op.sql == SQL
+ assert op.bucket == BUCKET
+ assert op.filename == FILENAME
+ assert op.impersonation_chain == IMPERSONATION_CHAIN
+
+ @patch("airflow.providers.google.cloud.transfers.trino_to_gcs.TrinoHook")
+ @patch("airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook")
+ def test_save_as_json(self, mock_gcs_hook, mock_trino_hook):
+ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+ assert BUCKET == bucket
+ assert FILENAME.format(0) == obj
+ assert "application/json" == mime_type
+ assert not gzip
+ with open(tmp_filename, "rb") as file:
+ assert b"".join(NDJSON_LINES) == file.read()
+
+ mock_gcs_hook.return_value.upload.side_effect = _assert_upload
+
+ mock_cursor = mock_trino_hook.return_value.get_conn.return_value.cursor
+
+ mock_cursor.return_value.description = [
+ ("some_num", "INTEGER", None, None, None, None, None),
+ ("some_str", "VARCHAR", None, None, None, None, None),
+ ]
+
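+ # Three data rows, then None to signal that the cursor is exhausted.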
+ mock_cursor.return_value.fetchone.side_effect = [
+ [42, "mock_row_content_1"],
+ [43, "mock_row_content_2"],
+ [44, "mock_row_content_3"],
+ None,
+ ]
+
+ op = TrinoToGCSOperator(
+ task_id=TASK_ID,
+ sql=SQL,
+ bucket=BUCKET,
+ filename=FILENAME,
+ trino_conn_id=TRINO_CONN_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+
+ op.execute(None)
+
+ mock_trino_hook.assert_called_once_with(trino_conn_id=TRINO_CONN_ID)
+ mock_gcs_hook.assert_called_once_with(
+ delegate_to=None,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+
+ mock_gcs_hook.return_value.upload.assert_called()
+
+ @patch("airflow.providers.google.cloud.transfers.trino_to_gcs.TrinoHook")
+ @patch("airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook")
+ def test_save_as_json_with_file_splitting(self, mock_gcs_hook, mock_trino_hook):
+ """Test that ndjson is split by approx_max_file_size_bytes param."""
+
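+ # Sizing the limit to the first chunk forces the operator to split after two NDJSON lines.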
+ expected_upload = {
+ FILENAME.format(0): b"".join(NDJSON_LINES[:2]),
+ FILENAME.format(1): NDJSON_LINES[2],
+ }
+
+ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+ assert BUCKET == bucket
+ assert "application/json" == mime_type
+ assert not gzip
+ with open(tmp_filename, "rb") as file:
+ assert expected_upload[obj] == file.read()
+
+ mock_gcs_hook.return_value.upload.side_effect = _assert_upload
+
+ mock_cursor = mock_trino_hook.return_value.get_conn.return_value.cursor
+
+ mock_cursor.return_value.description = [
+ ("some_num", "INTEGER", None, None, None, None, None),
+ ("some_str", "VARCHAR(20)", None, None, None, None, None),
+ ]
+
+ mock_cursor.return_value.fetchone.side_effect = [
+ [42, "mock_row_content_1"],
+ [43, "mock_row_content_2"],
+ [44, "mock_row_content_3"],
+ None,
+ ]
+
+ op = TrinoToGCSOperator(
+ task_id=TASK_ID,
+ sql=SQL,
+ bucket=BUCKET,
+ filename=FILENAME,
+ approx_max_file_size_bytes=len(expected_upload[FILENAME.format(0)]),
+ )
+
+ op.execute(None)
+
+ mock_gcs_hook.return_value.upload.assert_called()
+
+ @patch("airflow.providers.google.cloud.transfers.trino_to_gcs.TrinoHook")
+ @patch("airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook")
+ def test_save_as_json_with_schema_file(self, mock_gcs_hook, mock_trino_hook):
+ """Test writing schema files."""
+
+ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): # pylint: disable=unused-argument
+ if obj == SCHEMA_FILENAME:
+ with open(tmp_filename, "rb") as file:
+ assert SCHEMA_JSON == file.read()
+
+ mock_gcs_hook.return_value.upload.side_effect = _assert_upload
+
+ mock_cursor = mock_trino_hook.return_value.get_conn.return_value.cursor
+
+ mock_cursor.return_value.description = [
+ ("some_num", "INTEGER", None, None, None, None, None),
+ ("some_str", "VARCHAR", None, None, None, None, None),
+ ]
+
+ mock_cursor.return_value.fetchone.side_effect = [
+ [42, "mock_row_content_1"],
+ [43, "mock_row_content_2"],
+ [44, "mock_row_content_3"],
+ None,
+ ]
+
+ op = TrinoToGCSOperator(
+ task_id=TASK_ID,
+ sql=SQL,
+ bucket=BUCKET,
+ filename=FILENAME,
+ schema_filename=SCHEMA_FILENAME,
+ export_format="csv",
+ trino_conn_id=TRINO_CONN_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ )
+ op.execute(None)
+
+ # once for the file and once for the schema
+ assert 2 == mock_gcs_hook.return_value.upload.call_count
+
+ @patch("airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook")
+ @patch("airflow.providers.google.cloud.transfers.trino_to_gcs.TrinoHook")
+ def test_save_as_csv(self, mock_trino_hook, mock_gcs_hook):
+ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+ assert BUCKET == bucket
+ assert FILENAME.format(0) == obj
+ assert "text/csv" == mime_type
+ assert not gzip
+ with open(tmp_filename, "rb") as file:
+ assert b"".join(CSV_LINES) == file.read()
+
+ mock_gcs_hook.return_value.upload.side_effect = _assert_upload
+
+ mock_cursor = mock_trino_hook.return_value.get_conn.return_value.cursor
+
+ mock_cursor.return_value.description = [
+ ("some_num", "INTEGER", None, None, None, None, None),
+ ("some_str", "VARCHAR", None, None, None, None, None),
+ ]
+
+ mock_cursor.return_value.fetchone.side_effect = [
+ [42, "mock_row_content_1"],
+ [43, "mock_row_content_2"],
+ [44, "mock_row_content_3"],
+ None,
+ ]
+
+ op = TrinoToGCSOperator(
+ task_id=TASK_ID,
+ sql=SQL,
+ bucket=BUCKET,
+ filename=FILENAME,
+ export_format="csv",
+ trino_conn_id=TRINO_CONN_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+
+ op.execute(None)
+
+ mock_gcs_hook.return_value.upload.assert_called()
+
+ mock_trino_hook.assert_called_once_with(trino_conn_id=TRINO_CONN_ID)
+ mock_gcs_hook.assert_called_once_with(
+ delegate_to=None,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ )
+
+ @patch("airflow.providers.google.cloud.transfers.trino_to_gcs.TrinoHook")
+ @patch("airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook")
+ def test_save_as_csv_with_file_splitting(self, mock_gcs_hook, mock_trino_hook):
+ """Test that csv is split by approx_max_file_size_bytes param."""
+
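+ # The CSV header is rewritten in every chunk, so the second file is the header plus the last row.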
+ expected_upload = {
+ FILENAME.format(0): b"".join(CSV_LINES[:3]),
+ FILENAME.format(1): b"".join([CSV_LINES[0], CSV_LINES[3]]),
+ }
+
+ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip):
+ assert BUCKET == bucket
+ assert "text/csv" == mime_type
+ assert not gzip
+ with open(tmp_filename, "rb") as file:
+ assert expected_upload[obj] == file.read()
+
+ mock_gcs_hook.return_value.upload.side_effect = _assert_upload
+
+ mock_cursor = mock_trino_hook.return_value.get_conn.return_value.cursor
+
+ mock_cursor.return_value.description = [
+ ("some_num", "INTEGER", None, None, None, None, None),
+ ("some_str", "VARCHAR(20)", None, None, None, None, None),
+ ]
+
+ mock_cursor.return_value.fetchone.side_effect = [
+ [42, "mock_row_content_1"],
+ [43, "mock_row_content_2"],
+ [44, "mock_row_content_3"],
+ None,
+ ]
+
+ op = TrinoToGCSOperator(
+ task_id=TASK_ID,
+ sql=SQL,
+ bucket=BUCKET,
+ filename=FILENAME,
+ approx_max_file_size_bytes=len(expected_upload[FILENAME.format(0)]),
+ export_format="csv",
+ )
+
+ op.execute(None)
+
+ mock_gcs_hook.return_value.upload.assert_called()
+
+ @patch("airflow.providers.google.cloud.transfers.trino_to_gcs.TrinoHook")
+ @patch("airflow.providers.google.cloud.transfers.sql_to_gcs.GCSHook")
+ def test_save_as_csv_with_schema_file(self, mock_gcs_hook, mock_trino_hook):
+ """Test writing schema files."""
+
+ def _assert_upload(bucket, obj, tmp_filename, mime_type, gzip): # pylint: disable=unused-argument
+ if obj == SCHEMA_FILENAME:
+ with open(tmp_filename, "rb") as file:
+ assert SCHEMA_JSON == file.read()
+
+ mock_gcs_hook.return_value.upload.side_effect = _assert_upload
+
+ mock_cursor = mock_trino_hook.return_value.get_conn.return_value.cursor
+
+ mock_cursor.return_value.description = [
+ ("some_num", "INTEGER", None, None, None, None, None),
+ ("some_str", "VARCHAR", None, None, None, None, None),
+ ]
+
+ mock_cursor.return_value.fetchone.side_effect = [
+ [42, "mock_row_content_1"],
+ [43, "mock_row_content_2"],
+ [44, "mock_row_content_3"],
+ None,
+ ]
+
+ op = TrinoToGCSOperator(
+ task_id=TASK_ID,
+ sql=SQL,
+ bucket=BUCKET,
+ filename=FILENAME,
+ schema_filename=SCHEMA_FILENAME,
+ export_format="csv",
+ )
+ op.execute(None)
+
+ # once for the file and once for the schema
+ assert 2 == mock_gcs_hook.return_value.upload.call_count
diff --git a/tests/providers/google/cloud/transfers/test_trino_to_gcs_system.py b/tests/providers/google/cloud/transfers/test_trino_to_gcs_system.py
new file mode 100644
index 0000000000000..00d5716556183
--- /dev/null
+++ b/tests/providers/google/cloud/transfers/test_trino_to_gcs_system.py
@@ -0,0 +1,169 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import os
+from contextlib import closing, suppress
+
+import pytest
+
+from airflow.models import Connection
+from airflow.providers.trino.hooks.trino import TrinoHook
+from tests.providers.google.cloud.utils.gcp_authenticator import GCP_BIGQUERY_KEY, GCP_GCS_KEY
+from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context
+
+try:
+ from airflow.utils.session import create_session
+except ImportError:
+ # This is a hack to import create_session from old destination and
+ # fool the pre-commit check that looks for old imports...
+ # TODO remove this once we don't need to test this on 1.10
+ import importlib
+
+ db_module = importlib.import_module("airflow.utils.db")
+ create_session = getattr(db_module, "create_session")
+
+
+GCS_BUCKET = os.environ.get("GCP_TRINO_TO_GCS_BUCKET_NAME", "test-trino-to-gcs-bucket")
+DATASET_NAME = os.environ.get("GCP_TRINO_TO_GCS_DATASET_NAME", "test_trino_to_gcs_dataset")
+
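+# Cover one column per Trino type family so the transfer's type mapping is exercised end to end.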
+CREATE_QUERY = """
+CREATE TABLE memory.default.test_multiple_types (
+ -- Boolean
+ z_boolean BOOLEAN,
+ -- Integers
+ z_tinyint TINYINT,
+ z_smallint SMALLINT,
+ z_integer INTEGER,
+ z_bigint BIGINT,
+ -- Floating-Point
+ z_real REAL,
+ z_double DOUBLE,
+ -- Fixed-Point
+ z_decimal DECIMAL(10,2),
+ -- String
+ z_varchar VARCHAR(20),
+ z_char CHAR(20),
+ z_varbinary VARBINARY,
+ z_json JSON,
+ -- Date and Time
+ z_date DATE,
+ z_time TIME,
+ z_time_with_time_zone TIME WITH TIME ZONE,
+ z_timestamp TIMESTAMP,
+ z_timestamp_with_time_zone TIMESTAMP WITH TIME ZONE,
+ -- Network Address
+ z_ipaddress_v4 IPADDRESS,
+ z_ipaddress_v6 IPADDRESS,
+ -- UUID
+ z_uuid UUID
+)
+"""
+
+LOAD_QUERY = """
+INSERT INTO memory.default.test_multiple_types VALUES(
+ -- Boolean
+ true, -- z_boolean BOOLEAN,
+ -- Integers
+ CAST(POW(2, 7 ) - 42 AS TINYINT), -- z_tinyint TINYINT,
+ CAST(POW(2, 15) - 42 AS SMALLINT), -- z_smallint SMALLINT,
+ CAST(POW(2, 31) - 42 AS INTEGER), -- z_integer INTEGER,
+ CAST(POW(2, 32) - 42 AS BIGINT) * 2, -- z_bigint BIGINT,
+ -- Floating-Point
+ REAL '42', -- z_real REAL,
+ DOUBLE '1.03e42', -- z_double DOUBLE,
+ -- Fixed-Point
+ DECIMAL '1.1', -- z_decimal DECIMAL(10, 2),
+ -- String
+ U&'Hello winter \2603 !', -- z_vaarchar VARCHAR(20),
+ 'cat', -- z_char CHAR(20),
+ X'65683F', -- z_varbinary VARBINARY,
+ CAST('["A", 1, true]' AS JSON), -- z_json JSON,
+ -- Date and Time
+ DATE '2001-08-22', -- z_date DATE,
+ TIME '01:02:03.456', -- z_time TIME,
+ TIME '01:02:03.456 America/Los_Angeles', -- z_time_with_time_zone TIME WITH TIME ZONE,
+ TIMESTAMP '2001-08-22 03:04:05.321', -- z_timestamp TIMESTAMP,
+ TIMESTAMP '2001-08-22 03:04:05.321 America/Los_Angeles', -- z_timestamp_with_time_zone TIMESTAMP WITH TIME
+ -- ZONE,
+ -- Network Address
+ IPADDRESS '10.0.0.1', -- z_ipaddress_v4 IPADDRESS,
+ IPADDRESS '2001:db8::1', -- z_ipaddress_v6 IPADDRESS,
+ -- UUID
+ UUID '12151fd2-7586-11e9-8f9e-2a86e4085a59' -- z_uuid UUID
+)
+"""
+DELETE_QUERY = "DROP TABLE memory.default.test_multiple_types"
+
+
+@pytest.mark.integration("trino")
+class TrinoToGCSSystemTest(GoogleSystemTest):
+ @staticmethod
+ def init_connection():
+ with create_session() as session:
+ session.query(Connection).filter(Connection.conn_id == "trino_default").delete()
+ session.merge(
+ Connection(
+ conn_id="trino_default", conn_type="conn_type", host="trino", port=8080, login="airflow"
+ )
+ )
+
+ @staticmethod
+ def init_db():
+ hook = TrinoHook()
+ with hook.get_conn() as conn:
+ with closing(conn.cursor()) as cur:
+ cur.execute(CREATE_QUERY)
+ # Trino does not execute queries until the result is fetched. :-(
+ cur.fetchone()
+ cur.execute(LOAD_QUERY)
+ cur.fetchone()
+
+ @staticmethod
+ def drop_db():
+ hook = TrinoHook()
+ with hook.get_conn() as conn:
+ with closing(conn.cursor()) as cur:
+ cur.execute(DELETE_QUERY)
+ # Trino does not execute queries until the result is fetched. :-(
+ cur.fetchone()
+
+ @provide_gcp_context(GCP_GCS_KEY)
+ def setUp(self):
+ super().setUp()
+ self.init_connection()
+ self.create_gcs_bucket(GCS_BUCKET)
+ with suppress(Exception):
+ self.drop_db()
+ self.init_db()
+ self.execute_with_ctx(
+ ["bq", "rm", "--recursive", "--force", f"{self._project_id()}:{DATASET_NAME}"],
+ key=GCP_BIGQUERY_KEY,
+ )
+
+ @provide_gcp_context(GCP_BIGQUERY_KEY)
+ def test_run_example_dag(self):
+ self.run_dag("example_trino_to_gcs", CLOUD_DAG_FOLDER)
+
+ @provide_gcp_context(GCP_GCS_KEY)
+ def tearDown(self):
+ self.delete_gcs_bucket(GCS_BUCKET)
+ self.drop_db()
+ self.execute_with_ctx(
+ ["bq", "rm", "--recursive", "--force", f"{self._project_id()}:{DATASET_NAME}"],
+ key=GCP_BIGQUERY_KEY,
+ )
+ super().tearDown()
diff --git a/tests/providers/mysql/transfers/test_trino_to_mysql.py b/tests/providers/mysql/transfers/test_trino_to_mysql.py
new file mode 100644
index 0000000000000..2e23169cc5f29
--- /dev/null
+++ b/tests/providers/mysql/transfers/test_trino_to_mysql.py
@@ -0,0 +1,73 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import os
+import unittest
+from unittest.mock import patch
+
+from airflow.providers.mysql.transfers.trino_to_mysql import TrinoToMySqlOperator
+from tests.providers.apache.hive import DEFAULT_DATE, TestHiveEnvironment
+
+
+class TestTrinoToMySqlTransfer(TestHiveEnvironment):
+ def setUp(self):
+ self.kwargs = dict(
+ sql='sql',
+ mysql_table='mysql_table',
+ task_id='test_trino_to_mysql_transfer',
+ )
+ super().setUp()
+
+ @patch('airflow.providers.mysql.transfers.trino_to_mysql.MySqlHook')
+ @patch('airflow.providers.mysql.transfers.trino_to_mysql.TrinoHook')
+ def test_execute(self, mock_trino_hook, mock_mysql_hook):
+ TrinoToMySqlOperator(**self.kwargs).execute(context={})
+
+ mock_trino_hook.return_value.get_records.assert_called_once_with(self.kwargs['sql'])
+ mock_mysql_hook.return_value.insert_rows.assert_called_once_with(
+ table=self.kwargs['mysql_table'], rows=mock_trino_hook.return_value.get_records.return_value
+ )
+
+ @patch('airflow.providers.mysql.transfers.trino_to_mysql.MySqlHook')
+ @patch('airflow.providers.mysql.transfers.trino_to_mysql.TrinoHook')
+ def test_execute_with_mysql_preoperator(self, mock_trino_hook, mock_mysql_hook):
+ self.kwargs.update(dict(mysql_preoperator='mysql_preoperator'))
+
+ TrinoToMySqlOperator(**self.kwargs).execute(context={})
+
+ mock_trino_hook.return_value.get_records.assert_called_once_with(self.kwargs['sql'])
+ mock_mysql_hook.return_value.run.assert_called_once_with(self.kwargs['mysql_preoperator'])
+ mock_mysql_hook.return_value.insert_rows.assert_called_once_with(
+ table=self.kwargs['mysql_table'], rows=mock_trino_hook.return_value.get_records.return_value
+ )
+
+ @unittest.skipIf(
+ 'AIRFLOW_RUNALL_TESTS' not in os.environ, "Skipped because AIRFLOW_RUNALL_TESTS is not set"
+ )
+ def test_trino_to_mysql(self):
+ op = TrinoToMySqlOperator(
+ task_id='trino_to_mysql_check',
+ sql="""
+ SELECT name, count(*) as ccount
+ FROM airflow.static_babynames
+ GROUP BY name
+ """,
+ mysql_table='test_static_babynames',
+ mysql_preoperator='TRUNCATE TABLE test_static_babynames;',
+ dag=self.dag,
+ )
+ op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
diff --git a/tests/providers/presto/hooks/test_presto.py b/tests/providers/presto/hooks/test_presto.py
index f9e85875fd8de..e6ebb737c2c09 100644
--- a/tests/providers/presto/hooks/test_presto.py
+++ b/tests/providers/presto/hooks/test_presto.py
@@ -206,28 +206,3 @@ def test_get_pandas_df(self):
assert result_sets[1][0] == df.values.tolist()[1][0]
self.cur.execute.assert_called_once_with(statement, None)
-
-
-class TestPrestoHookIntegration(unittest.TestCase):
- @pytest.mark.integration("presto")
- @mock.patch.dict('os.environ', AIRFLOW_CONN_PRESTO_DEFAULT="presto://airflow@presto:8080/")
- def test_should_record_records(self):
- hook = PrestoHook()
- sql = "SELECT name FROM tpch.sf1.customer ORDER BY custkey ASC LIMIT 3"
- records = hook.get_records(sql)
- assert [['Customer#000000001'], ['Customer#000000002'], ['Customer#000000003']] == records
-
- @pytest.mark.integration("presto")
- @pytest.mark.integration("kerberos")
- def test_should_record_records_with_kerberos_auth(self):
- conn_url = (
- 'presto://airflow@presto:7778/?'
- 'auth=kerberos&kerberos__service_name=HTTP&'
- 'verify=False&'
- 'protocol=https'
- )
- with mock.patch.dict('os.environ', AIRFLOW_CONN_PRESTO_DEFAULT=conn_url):
- hook = PrestoHook()
- sql = "SELECT name FROM tpch.sf1.customer ORDER BY custkey ASC LIMIT 3"
- records = hook.get_records(sql)
- assert [['Customer#000000001'], ['Customer#000000002'], ['Customer#000000003']] == records
diff --git a/tests/providers/trino/__init__.py b/tests/providers/trino/__init__.py
new file mode 100644
index 0000000000000..217e5db960782
--- /dev/null
+++ b/tests/providers/trino/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/trino/hooks/__init__.py b/tests/providers/trino/hooks/__init__.py
new file mode 100644
index 0000000000000..217e5db960782
--- /dev/null
+++ b/tests/providers/trino/hooks/__init__.py
@@ -0,0 +1,17 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/trino/hooks/test_trino.py b/tests/providers/trino/hooks/test_trino.py
new file mode 100644
index 0000000000000..e649d2bece789
--- /dev/null
+++ b/tests/providers/trino/hooks/test_trino.py
@@ -0,0 +1,233 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import json
+import re
+import unittest
+from unittest import mock
+from unittest.mock import patch
+
+import pytest
+from parameterized import parameterized
+from trino.transaction import IsolationLevel
+
+from airflow import AirflowException
+from airflow.models import Connection
+from airflow.providers.trino.hooks.trino import TrinoHook
+
+
+class TestTrinoHookConn(unittest.TestCase):
+ @patch('airflow.providers.trino.hooks.trino.trino.auth.BasicAuthentication')
+ @patch('airflow.providers.trino.hooks.trino.trino.dbapi.connect')
+ @patch('airflow.providers.trino.hooks.trino.TrinoHook.get_connection')
+ def test_get_conn_basic_auth(self, mock_get_connection, mock_connect, mock_basic_auth):
+ mock_get_connection.return_value = Connection(
+ login='login', password='password', host='host', schema='hive'
+ )
+
+ conn = TrinoHook().get_conn()
+ mock_connect.assert_called_once_with(
+ catalog='hive',
+ host='host',
+ port=None,
+ http_scheme='http',
+ schema='hive',
+ source='airflow',
+ user='login',
+ isolation_level=0,
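+ # isolation_level 0 is IsolationLevel.AUTOCOMMIT, the hook's default.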
+ auth=mock_basic_auth.return_value,
+ )
+ mock_basic_auth.assert_called_once_with('login', 'password')
+ assert mock_connect.return_value == conn
+
+ @patch('airflow.providers.trino.hooks.trino.TrinoHook.get_connection')
+ def test_get_conn_invalid_auth(self, mock_get_connection):
+ mock_get_connection.return_value = Connection(
+ login='login',
+ password='password',
+ host='host',
+ schema='hive',
+ extra=json.dumps({'auth': 'kerberos'}),
+ )
+ with pytest.raises(
+ AirflowException, match=re.escape("Kerberos authorization doesn't support password.")
+ ):
+ TrinoHook().get_conn()
+
+ @patch('airflow.providers.trino.hooks.trino.trino.auth.KerberosAuthentication')
+ @patch('airflow.providers.trino.hooks.trino.trino.dbapi.connect')
+ @patch('airflow.providers.trino.hooks.trino.TrinoHook.get_connection')
+ def test_get_conn_kerberos_auth(self, mock_get_connection, mock_connect, mock_auth):
+ mock_get_connection.return_value = Connection(
+ login='login',
+ host='host',
+ schema='hive',
+ extra=json.dumps(
+ {
+ 'auth': 'kerberos',
+ 'kerberos__config': 'TEST_KERBEROS_CONFIG',
+ 'kerberos__service_name': 'TEST_SERVICE_NAME',
+ 'kerberos__mutual_authentication': 'TEST_MUTUAL_AUTHENTICATION',
+ 'kerberos__force_preemptive': True,
+ 'kerberos__hostname_override': 'TEST_HOSTNAME_OVERRIDE',
+ 'kerberos__sanitize_mutual_error_response': True,
+ 'kerberos__principal': 'TEST_PRINCIPAL',
+ 'kerberos__delegate': 'TEST_DELEGATE',
+ 'kerberos__ca_bundle': 'TEST_CA_BUNDLE',
+ }
+ ),
+ )
+
+ conn = TrinoHook().get_conn()
+ mock_connect.assert_called_once_with(
+ catalog='hive',
+ host='host',
+ port=None,
+ http_scheme='http',
+ schema='hive',
+ source='airflow',
+ user='login',
+ isolation_level=0,
+ auth=mock_auth.return_value,
+ )
+ mock_auth.assert_called_once_with(
+ ca_bundle='TEST_CA_BUNDLE',
+ config='TEST_KERBEROS_CONFIG',
+ delegate='TEST_DELEGATE',
+ force_preemptive=True,
+ hostname_override='TEST_HOSTNAME_OVERRIDE',
+ mutual_authentication='TEST_MUTUAL_AUTHENTICATION',
+ principal='TEST_PRINCIPAL',
+ sanitize_mutual_error_response=True,
+ service_name='TEST_SERVICE_NAME',
+ )
+ assert mock_connect.return_value == conn
+
+ @parameterized.expand(
+ [
+ ('False', False),
+ ('false', False),
+ ('true', True),
+ ('True', True),
+ ('/tmp/cert.crt', '/tmp/cert.crt'),
+ ]
+ )
+ def test_get_conn_verify(self, current_verify, expected_verify):
+ patcher_connect = patch('airflow.providers.trino.hooks.trino.trino.dbapi.connect')
+ patcher_get_connections = patch('airflow.providers.trino.hooks.trino.TrinoHook.get_connection')
+
+ with patcher_connect as mock_connect, patcher_get_connections as mock_get_connection:
+ mock_get_connection.return_value = Connection(
+ login='login', host='host', schema='hive', extra=json.dumps({'verify': current_verify})
+ )
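+ # Attach a PropertyMock so the test can assert what the hook assigns to _http_session.verify.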
+ mock_verify = mock.PropertyMock()
+ type(mock_connect.return_value._http_session).verify = mock_verify
+
+ conn = TrinoHook().get_conn()
+ mock_verify.assert_called_once_with(expected_verify)
+ assert mock_connect.return_value == conn
+
+
+class TestTrinoHook(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+
+ self.cur = mock.MagicMock()
+ self.conn = mock.MagicMock()
+ self.conn.cursor.return_value = self.cur
+ conn = self.conn
+
+ class UnitTestTrinoHook(TrinoHook):
+ conn_name_attr = 'test_conn_id'
+
+ def get_conn(self):
+ return conn
+
+ def get_isolation_level(self):
+ return IsolationLevel.READ_COMMITTED
+
+ self.db_hook = UnitTestTrinoHook()
+
+ @patch('airflow.hooks.dbapi.DbApiHook.insert_rows')
+ def test_insert_rows(self, mock_insert_rows):
+ table = "table"
+ rows = [("hello",), ("world",)]
+ target_fields = None
+ commit_every = 10
+ self.db_hook.insert_rows(table, rows, target_fields, commit_every)
+ mock_insert_rows.assert_called_once_with(table, rows, None, 10)
+
+ def test_get_first_record(self):
+ statement = 'SQL'
+ result_sets = [('row1',), ('row2',)]
+ self.cur.fetchone.return_value = result_sets[0]
+
+ assert result_sets[0] == self.db_hook.get_first(statement)
+ self.conn.close.assert_called_once_with()
+ self.cur.close.assert_called_once_with()
+ self.cur.execute.assert_called_once_with(statement)
+
+ def test_get_records(self):
+ statement = 'SQL'
+ result_sets = [('row1',), ('row2',)]
+ self.cur.fetchall.return_value = result_sets
+
+ assert result_sets == self.db_hook.get_records(statement)
+ self.conn.close.assert_called_once_with()
+ self.cur.close.assert_called_once_with()
+ self.cur.execute.assert_called_once_with(statement)
+
+ def test_get_pandas_df(self):
+ statement = 'SQL'
+ column = 'col'
+ result_sets = [('row1',), ('row2',)]
+ self.cur.description = [(column,)]
+ self.cur.fetchall.return_value = result_sets
+ df = self.db_hook.get_pandas_df(statement)
+
+ assert column == df.columns[0]
+
+ assert result_sets[0][0] == df.values.tolist()[0][0]
+ assert result_sets[1][0] == df.values.tolist()[1][0]
+
+ self.cur.execute.assert_called_once_with(statement, None)
+
+
+class TestTrinoHookIntegration(unittest.TestCase):
+ @pytest.mark.integration("trino")
+ @mock.patch.dict('os.environ', AIRFLOW_CONN_TRINO_DEFAULT="trino://airflow@trino:8080/")
+ def test_should_record_records(self):
+ hook = TrinoHook()
+ sql = "SELECT name FROM tpch.sf1.customer ORDER BY custkey ASC LIMIT 3"
+ records = hook.get_records(sql)
+ assert [['Customer#000000001'], ['Customer#000000002'], ['Customer#000000003']] == records
+
+ @pytest.mark.integration("trino")
+ @pytest.mark.integration("kerberos")
+ def test_should_record_records_with_kerberos_auth(self):
+ conn_url = (
+ 'trino://airflow@trino.example.com:7778/?'
+ 'auth=kerberos&kerberos__service_name=HTTP&'
+ 'verify=False&'
+ 'protocol=https'
+ )
+ with mock.patch.dict('os.environ', AIRFLOW_CONN_TRINO_DEFAULT=conn_url):
+ hook = TrinoHook()
+ sql = "SELECT name FROM tpch.sf1.customer ORDER BY custkey ASC LIMIT 3"
+ records = hook.get_records(sql)
+ assert [['Customer#000000001'], ['Customer#000000002'], ['Customer#000000003']] == records