Skip to content

Prune sqlglot transpiler and customization #1452

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
  •  
  •  
  •  
1 change: 0 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ setup_python:


dev:
pip install hatch
hatch env create
hatch run pip install --upgrade pip
hatch run pip install -e '.[test]'
Expand Down
13 changes: 6 additions & 7 deletions src/databricks/labs/remorph/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,15 @@
from databricks.labs.remorph.install import WorkspaceInstaller
from databricks.labs.remorph.reconcile.runner import ReconcileRunner
from databricks.labs.remorph.lineage import lineage_generator
from databricks.labs.remorph.reconcile.utils import dialect_exists
from databricks.labs.remorph.transpiler.execute import transpile as do_transpile
from databricks.labs.remorph.reconcile.execute import RECONCILE_OPERATION_NAME, AGG_RECONCILE_OPERATION_NAME
from databricks.labs.remorph.jvmproxy import proxy_command
from databricks.labs.remorph.transpiler.lsp_engine import LSPEngine
from databricks.sdk.core import with_user_agent_extra

from databricks.sdk import WorkspaceClient

from databricks.labs.remorph.transpiler.sqlglot.sqlglot_engine import SqlglotEngine
from databricks.labs.remorph.transpiler.transpile_engine import TranspileEngine

remorph = App(__file__)
logger = get_logger(__file__)

Expand Down Expand Up @@ -88,7 +87,7 @@ def transpile(
if not default_config:
raise SystemExit("Installed transpile config not found. Please install Remorph transpile first.")
_override_workspace_client_config(ctx, default_config.sdk_config)
engine = TranspileEngine.load_engine(Path(transpiler_config_path))
engine = LSPEngine.from_config_path(Path(transpiler_config_path))
engine.check_source_dialect(source_dialect)
if not input_source or not os.path.exists(input_source):
raise_validation_exception(f"Invalid value for '--input-source': Path '{input_source}' does not exist.")
Expand Down Expand Up @@ -178,14 +177,14 @@ def generate_lineage(w: WorkspaceClient, source_dialect: str, input_source: str,
"""[Experimental] Generates a lineage of source SQL files or folder"""
ctx = ApplicationContext(w)
logger.debug(f"User: {ctx.current_user}")
engine = SqlglotEngine()
engine.check_source_dialect(source_dialect)
if not source_dialect or not dialect_exists(source_dialect):
raise_validation_exception(f"Invalid value for '--source-dialect': {source_dialect}.")
if not input_source or not os.path.exists(input_source):
raise_validation_exception(f"Invalid value for '--input-source': Path '{input_source}' does not exist.")
if not os.path.exists(output_folder) or output_folder in {None, ""}:
raise_validation_exception(f"Invalid value for '--output-folder': Path '{output_folder}' does not exist.")

lineage_generator(engine, source_dialect, input_source, output_folder)
lineage_generator(source_dialect, input_source, output_folder)


@remorph.command
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from pathlib import Path

from sqlglot.dialects import Snowflake, Databricks

from databricks.labs.blueprint.wheels import ProductInfo
from databricks.labs.remorph.coverage import commons
from databricks.labs.remorph.transpiler.sqlglot.generator.databricks import Databricks
from databricks.labs.remorph.transpiler.sqlglot.parsers.snowflake import Snowflake

if __name__ == "__main__":
input_dir = commons.get_env_var("INPUT_DIR_PARENT", required=True)
Expand Down
2 changes: 1 addition & 1 deletion src/databricks/labs/remorph/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from databricks.labs.remorph.deployment.configurator import ResourceConfigurator
from databricks.labs.remorph.deployment.installation import WorkspaceInstallation
from databricks.labs.remorph.reconcile.constants import ReconReportType, ReconSourceType
from databricks.labs.remorph.transpiler.lsp.lsp_engine import LSPConfig
from databricks.labs.remorph.transpiler.lsp_engine import LSPConfig


logger = logging.getLogger(__name__)
Expand Down
30 changes: 25 additions & 5 deletions src/databricks/labs/remorph/intermediate/root_tables.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import logging
from collections.abc import Iterable
from pathlib import Path
from sqlglot import parse, Expression, ErrorLevel
from sqlglot.expressions import Create, Insert, Merge, Join, Select, Table, With

from databricks.labs.remorph.helpers.file_utils import (
get_sql_file,
Expand All @@ -8,15 +11,12 @@
)
from databricks.labs.remorph.intermediate.dag import DAG

from databricks.labs.remorph.transpiler.sqlglot.sqlglot_engine import SqlglotEngine

logger = logging.getLogger(__name__)


class RootTableAnalyzer:

def __init__(self, engine: SqlglotEngine, source_dialect: str, input_path: Path):
self.engine = engine
def __init__(self, source_dialect: str, input_path: Path):
    """Initialize the analyzer.

    Args:
        source_dialect: sqlglot dialect name used when parsing the SQL
            (passed as ``read=`` to ``sqlglot.parse``).
        input_path: path to the SQL file or folder to analyse.
    """
    self.source_dialect = source_dialect
    self.input_path = input_path

Expand All @@ -39,6 +39,26 @@ def generate_lineage_dag(self) -> DAG:
return dag

def _populate_dag(self, sql_content: str, path: Path, dag: DAG):
    """Extract lineage pairs from one SQL file and record them in *dag*.

    Each (root_table, child) pair yielded by the analysis becomes a node
    for the child plus an edge root_table -> child.
    """
    # NOTE: this span in the diff contained both the old call
    # (self.engine.analyse_table_lineage) and its replacement; only the
    # new, engine-free call is kept.
    for root_table, child in self._analyse_table_lineage(sql_content, path):
        dag.add_node(child)
        dag.add_edge(root_table, child)

def _analyse_table_lineage(self, source_code: str, file_path: Path) -> Iterable[tuple[str, str]]:
    """Yield (root_table, child) lineage pairs found in *source_code*.

    The child defaults to the file path; when a statement contains a
    CREATE/INSERT/MERGE, the target table of that change becomes the
    child instead. Parsing uses ``ErrorLevel.IMMEDIATE``, so syntax
    errors propagate to the caller.
    """
    statements = parse(source_code, read=self.source_dialect, error_level=ErrorLevel.IMMEDIATE)
    if statements is None:
        return
    for statement in statements:
        if statement is None:
            continue
        # TODO: fix possible issue where the file reference is lost (if we have a 'create')
        child_node: str = str(file_path)
        for change in statement.find_all(Create, Insert, Merge, bfs=False):
            child_node = self._find_root_table(change)
        for query in statement.find_all(Select, Join, With, bfs=False):
            source_table = self._find_root_table(query)
            if source_table:
                yield source_table, child_node

@staticmethod
def _find_root_table(exp: Expression) -> str:
    """Return the name of the first ``Table`` node in *exp*, or "" if none."""
    first_table = exp.find(Table, bfs=False)
    if first_table is None:
        return ""
    return first_table.name
5 changes: 2 additions & 3 deletions src/databricks/labs/remorph/lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from databricks.labs.remorph.intermediate.dag import DAG
from databricks.labs.remorph.intermediate.root_tables import RootTableAnalyzer
from databricks.labs.remorph.transpiler.sqlglot.sqlglot_engine import SqlglotEngine

logger = logging.getLogger(__name__)

Expand All @@ -21,13 +20,13 @@ def _generate_dot_file_contents(dag: DAG) -> str:
return _lineage_str


def lineage_generator(engine: SqlglotEngine, source_dialect: str, input_source: str, output_folder: str):
def lineage_generator(source_dialect: str, input_source: str, output_folder: str):
input_sql_path = Path(input_source)
output_folder = output_folder if output_folder.endswith('/') else output_folder + '/'

msg = f"Processing for SQLs at this location: {input_sql_path}"
logger.info(msg)
root_table_analyzer = RootTableAnalyzer(engine, source_dialect, input_sql_path)
root_table_analyzer = RootTableAnalyzer(source_dialect, input_sql_path)
generated_dag = root_table_analyzer.generate_lineage_dag()
lineage_file_content = _generate_dot_file_contents(generated_dag)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
from pyspark.sql import SparkSession
from sqlglot import Dialect
from sqlglot.dialects import TSQL
from sqlglot.dialects import TSQL, Snowflake, Oracle, Databricks

from databricks.labs.remorph.reconcile.connectors.data_source import DataSource
from databricks.labs.remorph.reconcile.connectors.databricks import DatabricksDataSource
from databricks.labs.remorph.reconcile.connectors.oracle import OracleDataSource
from databricks.labs.remorph.reconcile.connectors.snowflake import SnowflakeDataSource
from databricks.labs.remorph.reconcile.connectors.sql_server import SQLServerDataSource
from databricks.labs.remorph.transpiler.sqlglot.generator.databricks import Databricks
from databricks.labs.remorph.transpiler.sqlglot.parsers.oracle import Oracle
from databricks.labs.remorph.transpiler.sqlglot.parsers.snowflake import Snowflake
from databricks.sdk import WorkspaceClient


Expand Down
2 changes: 1 addition & 1 deletion src/databricks/labs/remorph/reconcile/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
from pyspark.sql import DataFrame, SparkSession
from sqlglot import Dialect

from databricks.labs.remorph.reconcile.utils import get_dialect
from databricks.labs.remorph.config import (
DatabaseConfig,
TableRecon,
ReconcileConfig,
ReconcileMetadataConfig,
)
from databricks.labs.remorph.transpiler.sqlglot.dialect_utils import get_dialect
from databricks.labs.remorph.reconcile.compare import (
capture_mismatch_data_and_columns,
reconcile_data,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def _get_layer_query(self, group_list: list[Aggregate]) -> AggregateQueryRules:
layer=self.layer,
group_by_columns=group_list[0].group_by_columns,
group_by_columns_as_str=group_list[0].group_by_columns_as_str,
query=query_exp.sql(dialect=self.engine),
query=query_exp.sql(dialect=self._dialect),
rules=query_agg_rules,
)
return agg_query_rules
Expand Down
20 changes: 9 additions & 11 deletions src/databricks/labs/remorph/reconcile/query_builder/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
import sqlglot.expressions as exp
from sqlglot import Dialect, parse_one

from databricks.labs.remorph.reconcile.utils import get_dialect, get_dialect_name
from databricks.labs.remorph.reconcile.exception import InvalidInputException
from databricks.labs.remorph.reconcile.query_builder.expression_generator import (
DataType_transform_mapping,
transform_expression,
)
from databricks.labs.remorph.reconcile.recon_config import Schema, Table, Aggregate
from databricks.labs.remorph.transpiler.sqlglot.dialect_utils import get_dialect, SQLGLOT_DIALECTS

logger = logging.getLogger(__name__)

Expand All @@ -21,16 +21,16 @@ def __init__(
table_conf: Table,
schema: list[Schema],
layer: str,
engine: Dialect,
dialect: Dialect,
):
self._table_conf = table_conf
self._schema = schema
self._layer = layer
self._engine = engine
self._dialect = dialect

@property
def engine(self) -> Dialect:
return self._engine
def dialect(self) -> Dialect:
return self._dialect

@property
def layer(self) -> str:
Expand Down Expand Up @@ -93,7 +93,7 @@ def _apply_user_transformation(self, aliases: list[exp.Expression]) -> list[exp.

def _user_transformer(self, node: exp.Expression, user_transformations: dict[str, str]) -> exp.Expression:
if isinstance(node, exp.Column) and user_transformations:
dialect = self.engine if self.layer == "source" else get_dialect("databricks")
dialect = self._dialect if self.layer == "source" else get_dialect("databricks")
column_name = node.name
if column_name in user_transformations.keys():
return parse_one(user_transformations.get(column_name, column_name), read=dialect)
Expand All @@ -108,13 +108,11 @@ def _apply_default_transformation(
return with_transform

@staticmethod
def _default_transformer(node: exp.Expression, schema: list[Schema], source: Dialect) -> exp.Expression:
def _default_transformer(node: exp.Expression, schema: list[Schema], dialect: Dialect) -> exp.Expression:

def _get_transform(datatype: str):
source_dialects = [source_key for source_key, dialect in SQLGLOT_DIALECTS.items() if dialect == source]
source_dialect = source_dialects[0] if source_dialects else "universal"

source_mapping = DataType_transform_mapping.get(source_dialect, {})
dialect_name = get_dialect_name(dialect)
source_mapping = DataType_transform_mapping.get(dialect_name, {})

if source_mapping.get(datatype.upper()) is not None:
return source_mapping.get(datatype.upper())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sqlglot import Dialect
from sqlglot import expressions as exp

from databricks.labs.remorph.transpiler.sqlglot.dialect_utils import get_dialect
from databricks.labs.remorph.reconcile.utils import get_dialect
from databricks.labs.remorph.reconcile.recon_config import HashAlgoMapping


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sqlglot.expressions as exp
from sqlglot import Dialect

from databricks.labs.remorph.reconcile.utils import get_dialect
from databricks.labs.remorph.reconcile.query_builder.base import QueryBuilder
from databricks.labs.remorph.reconcile.query_builder.expression_generator import (
build_column,
Expand All @@ -11,7 +12,6 @@
lower,
transform_expression,
)
from databricks.labs.remorph.transpiler.sqlglot.dialect_utils import get_dialect

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -59,7 +59,7 @@ def build_query(self, report_type: str) -> str:
)
hash_col_with_transform = [self._generate_hash_algorithm(hashcols_sorted_as_src_seq, _HASH_COLUMN_NAME)]

dialect = self.engine if self.layer == "source" else get_dialect("databricks")
dialect = self._dialect if self.layer == "source" else get_dialect("databricks")
res = (
exp.select(*hash_col_with_transform + key_cols_with_transform)
.from_(":tbl")
Expand All @@ -77,11 +77,11 @@ def _generate_hash_algorithm(
) -> exp.Expression:
cols_with_alias = [build_column(this=col, alias=None) for col in cols]
cols_with_transform = self.add_transformations(
cols_with_alias, self.engine if self.layer == "source" else get_dialect("databricks")
cols_with_alias, self._dialect if self.layer == "source" else get_dialect("databricks")
)
col_exprs = exp.select(*cols_with_transform).iter_expressions()
concat_expr = concat(list(col_exprs))

hash_expr = concat_expr.transform(_hash_transform, self.engine, self.layer).transform(lower, is_expr=True)
hash_expr = concat_expr.transform(_hash_transform, self._dialect, self.layer).transform(lower, is_expr=True)

return build_column(hash_expr, alias=column_alias)
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import logging

import sqlglot.expressions as exp
from pyspark.sql import DataFrame
from sqlglot import select
from pyspark.sql import DataFrame

from databricks.labs.remorph.transpiler.sqlglot.dialect_utils import get_key_from_dialect
from databricks.labs.remorph.reconcile.utils import get_dialect_name
from databricks.labs.remorph.reconcile.query_builder.base import QueryBuilder
from databricks.labs.remorph.reconcile.query_builder.expression_generator import (
build_column,
Expand Down Expand Up @@ -52,7 +52,7 @@ def build_query(self, df: DataFrame):
for col in cols
]

sql_with_transforms = self.add_transformations(cols_with_alias, self.engine)
sql_with_transforms = self.add_transformations(cols_with_alias, self._dialect)
query_sql = select(*sql_with_transforms).from_(":tbl").where(self.filter)
if self.layer == "source":
with_select = [build_column(this=col, table_name="src") for col in sorted(cols)]
Expand All @@ -69,7 +69,7 @@ def build_query(self, df: DataFrame):
.select(*with_select)
.from_("src")
.join(join_clause)
.sql(dialect=self.engine)
.sql(dialect=self._dialect)
)
logger.info(f"Sampling Query for {self.layer}: {query}")
return query
Expand Down Expand Up @@ -97,7 +97,7 @@ def _get_with_clause(self, df: DataFrame) -> exp.Select:
)
for col, value in zip(df.columns, row)
]
if get_key_from_dialect(self.engine) == "oracle":
if get_dialect_name(self._dialect) == "oracle":
union_res.append(select(*row_select).from_("dual"))
else:
union_res.append(select(*row_select))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from sqlglot import expressions as exp
from sqlglot import select

from databricks.labs.remorph.reconcile.utils import get_dialect
from databricks.labs.remorph.reconcile.query_builder.base import QueryBuilder
from databricks.labs.remorph.reconcile.query_builder.expression_generator import (
anonymous,
Expand All @@ -17,7 +18,6 @@
coalesce,
)
from databricks.labs.remorph.reconcile.recon_config import ColumnThresholds
from databricks.labs.remorph.transpiler.sqlglot.generator.databricks import Databricks

logger = logging.getLogger(__name__)

Expand All @@ -34,7 +34,13 @@ def build_comparison_query(self) -> str:
select_clause, where = self._generate_select_where_clause(join_columns)
from_clause, join_clause = self._generate_from_and_join_clause(join_columns)
# for threshold comparison query the dialect is always Databricks
query = select(*select_clause).from_(from_clause).join(join_clause).where(where).sql(dialect=Databricks)
query = (
select(*select_clause)
.from_(from_clause)
.join(join_clause)
.where(where)
.sql(dialect=get_dialect("databricks"))
)
logger.info(f"Threshold Comparison query: {query}")
return query

Expand Down Expand Up @@ -226,6 +232,6 @@ def build_threshold_query(self) -> str:
if self.user_transformations:
thresholds_expr = self._apply_user_transformation(threshold_alias)

query = (select(*keys_expr + thresholds_expr).from_(":tbl").where(self.filter)).sql(dialect=self.engine)
query = (select(*keys_expr + thresholds_expr).from_(":tbl").where(self.filter)).sql(dialect=self._dialect)
logger.info(f"Threshold Query for {self.layer}: {query}")
return query
Loading
Loading