Skip to content

Commit

Permalink
refactor(utils/decorators): rewrite remove task decorator to use ast
Browse files Browse the repository at this point in the history
  • Loading branch information
josix committed Jan 27, 2025
1 parent b998715 commit 066d757
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 54 deletions.
73 changes: 37 additions & 36 deletions airflow/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,54 +18,55 @@
from __future__ import annotations

import sys
from collections import deque
from typing import Callable, TypeVar

import libcst as cst

T = TypeVar("T", bound=Callable)


class _TaskDecoratorRemover(cst.CSTTransformer):
def __init__(self, task_decorator_name):
self.decorators_to_remove = {
"setup",
"teardown",
"task.skip_if",
"task.run_if",
task_decorator_name.strip("@"),
}

def _is_task_decorator(self, decorator: cst.Decorator) -> bool:
if isinstance(decorator.decorator, cst.Name):
return decorator.decorator.value in self.decorators_to_remove
elif isinstance(decorator.decorator, cst.Attribute):
if isinstance(decorator.decorator.value, cst.Name):
return (
f"{decorator.decorator.value.value}.{decorator.decorator.attr.value}"
in self.decorators_to_remove
)
elif isinstance(decorator.decorator, cst.Call):
return self._is_task_decorator(cst.Decorator(decorator=decorator.decorator.func))
return False

def leave_FunctionDef(
self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
) -> cst.FunctionDef:
new_decorators = [dec for dec in updated_node.decorators if not self._is_task_decorator(dec)]
if len(new_decorators) == len(updated_node.decorators):
return updated_node
return updated_node.with_changes(decorators=new_decorators)


def remove_task_decorator(python_source: str, task_decorator_name: str) -> str:
"""
Remove @task or similar decorators as well as @setup and @teardown.
:param python_source: The python source code
:param task_decorator_name: the decorator name
TODO: Python 3.9+: Rewrite this to use ast.parse and ast.unparse
"""

def _remove_task_decorator(py_source, decorator_name):
# if no line starts with @decorator_name, we can early exit
for line in py_source.split("\n"):
if line.startswith(decorator_name):
break
else:
return python_source
split = python_source.split(decorator_name, 1)
before_decorator, after_decorator = split[0], split[1]
if after_decorator[0] == "(":
after_decorator = _balance_parens(after_decorator)
if after_decorator[0] == "\n":
after_decorator = after_decorator[1:]
return before_decorator + after_decorator

decorators = ["@setup", "@teardown", "@task.skip_if", "@task.run_if", task_decorator_name]
for decorator in decorators:
python_source = _remove_task_decorator(python_source, decorator)
return python_source


def _balance_parens(after_decorator):
num_paren = 1
after_decorator = deque(after_decorator)
after_decorator.popleft()
while num_paren:
current = after_decorator.popleft()
if current == "(":
num_paren = num_paren + 1
elif current == ")":
num_paren = num_paren - 1
return "".join(after_decorator)
source_tree = cst.parse_module(python_source)
modified_tree = source_tree.visit(_TaskDecoratorRemover(task_decorator_name))
return modified_tree.code


class _autostacklevel_warn:
Expand Down
1 change: 1 addition & 0 deletions hatch_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@
"jinja2>=3.0.0",
"jsonschema>=4.18.0",
"lazy-object-proxy>=1.2.0",
"libcst >=1.1.0",
"linkify-it-py>=2.0.0",
"lockfile>=0.12.2",
"markdown-it-py>=2.1.0",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,25 +192,25 @@ def test_should_create_virtualenv_with_extra_packages_uv(self, mock_execute_in_s
)

def test_remove_task_decorator(self):
py_source = '@task.virtualenv(serializer="dill")\ndef f():\nimport funcsigs'
py_source = '@task.virtualenv(serializer="dill")\ndef f():\n import funcsigs'
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv")
assert res == "def f():\nimport funcsigs"
assert res == "def f():\n import funcsigs"

def test_remove_decorator_no_parens(self):
py_source = "@task.virtualenv\ndef f():\nimport funcsigs"
py_source = "@task.virtualenv\ndef f():\n import funcsigs"
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv")
assert res == "def f():\nimport funcsigs"
assert res == "def f():\n import funcsigs"

def test_remove_decorator_including_comment(self):
py_source = "@task.virtualenv\ndef f():\n# @task.virtualenv\nimport funcsigs"
py_source = "@task.virtualenv\ndef f():\n# @task.virtualenv\n import funcsigs"
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv")
assert res == "def f():\n# @task.virtualenv\nimport funcsigs"
assert res == "def f():\n# @task.virtualenv\n import funcsigs"

def test_remove_decorator_nested(self):
py_source = "@foo\n@task.virtualenv\n@bar\ndef f():\nimport funcsigs"
py_source = "@foo\n@task.virtualenv\n@bar\ndef f():\n import funcsigs"
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv")
assert res == "@foo\n@bar\ndef f():\nimport funcsigs"
assert res == "@foo\n@bar\ndef f():\n import funcsigs"

py_source = "@foo\n@task.virtualenv()\n@bar\ndef f():\nimport funcsigs"
py_source = "@foo\n@task.virtualenv()\n@bar\ndef f():\n import funcsigs"
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.virtualenv")
assert res == "@foo\n@bar\ndef f():\nimport funcsigs"
assert res == "@foo\n@bar\ndef f():\n import funcsigs"
16 changes: 8 additions & 8 deletions tests/utils/test_preexisting_python_virtualenv_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,20 @@

class TestExternalPythonDecorator:
def test_remove_task_decorator(self):
py_source = '@task.external_python(serializer="dill")\ndef f():\nimport funcsigs'
py_source = '@task.external_python(serializer="dill")\ndef f():\n import funcsigs'
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python")
assert res == "def f():\nimport funcsigs"
assert res == "def f():\n import funcsigs"

def test_remove_decorator_no_parens(self):
py_source = "@task.external_python\ndef f():\nimport funcsigs"
py_source = "@task.external_python\ndef f():\n import funcsigs"
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python")
assert res == "def f():\nimport funcsigs"
assert res == "def f():\n import funcsigs"

def test_remove_decorator_nested(self):
py_source = "@foo\n@task.external_python\n@bar\ndef f():\nimport funcsigs"
py_source = "@foo\n@task.external_python\n@bar\ndef f():\n import funcsigs"
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python")
assert res == "@foo\n@bar\ndef f():\nimport funcsigs"
assert res == "@foo\n@bar\ndef f():\n import funcsigs"

py_source = "@foo\n@task.external_python()\n@bar\ndef f():\nimport funcsigs"
py_source = "@foo\n@task.external_python()\n@bar\ndef f():\n import funcsigs"
res = remove_task_decorator(python_source=py_source, task_decorator_name="@task.external_python")
assert res == "@foo\n@bar\ndef f():\nimport funcsigs"
assert res == "@foo\n@bar\ndef f():\n import funcsigs"

0 comments on commit 066d757

Please sign in to comment.