# Post-patch state of airflow/utils/decorators.py for hunk "@@ -17,55 +17,46 @@"
# of commit 4e926d5c54201444aa6fc5a457b2335c118a94f6
# ("refactor(utils/decorators): rewrite remove task decorator to use ast").
# The license header that precedes this hunk and the ``_autostacklevel_warn``
# class that follows it are outside the visible patch and are not reproduced.
from __future__ import annotations

import ast
import sys
from typing import Callable, TypeVar

T = TypeVar("T", bound=Callable)


class _TaskDecoratorRemover(ast.NodeTransformer):
    """
    Strip task-related decorators from every function definition in a tree.

    The decorators removed are ``setup``, ``teardown``, ``task.skip_if``,
    ``task.run_if`` and the caller-supplied task decorator, whether they are
    written bare (``@task``) or called (``@task.virtualenv(...)``).

    :param task_decorator_name: the decorator name, with or without a leading ``@``
    """

    def __init__(self, task_decorator_name: str) -> None:
        self.decorators_to_remove: set[str] = {
            "setup",
            "teardown",
            "task.skip_if",
            "task.run_if",
            # Callers historically pass the name with its "@" (e.g. "@task.docker");
            # AST names never contain it, so normalise before comparing.
            task_decorator_name.strip().lstrip("@"),
        }

    def visit_FunctionDef(self, node):
        """Drop matching decorators from ``node`` and recurse into its body."""
        node.decorator_list = [d for d in node.decorator_list if not self._is_task_decorator(d)]
        return self.generic_visit(node)

    # ``async def`` functions carry decorators in exactly the same way.
    visit_AsyncFunctionDef = visit_FunctionDef

    @staticmethod
    def _dotted_name(node: ast.expr) -> str | None:
        """Return ``a.b.c`` for a Name/Attribute chain, or None for any other expression."""
        parts: list[str] = []
        while isinstance(node, ast.Attribute):
            parts.append(node.attr)
            node = node.value
        if isinstance(node, ast.Name):
            parts.append(node.id)
            return ".".join(reversed(parts))
        # e.g. ``factory().deco`` or a subscript: not a plain dotted name.
        return None

    def _is_task_decorator(self, decorator: ast.expr) -> bool:
        """Tell whether ``decorator`` (bare or called) is one of the names to remove."""
        # ``@task.virtualenv(...)`` is an ast.Call wrapping the dotted name.
        while isinstance(decorator, ast.Call):
            decorator = decorator.func
        name = self._dotted_name(decorator)
        return name is not None and name in self.decorators_to_remove


def remove_task_decorator(python_source: str, task_decorator_name: str) -> str:
    """
    Remove @task or similar decorators as well as @setup and @teardown.

    The source is parsed with :func:`ast.parse` and regenerated with
    :func:`ast.unparse`, so comments and original formatting are not preserved
    in the returned string.

    :param python_source: The python source code
    :param task_decorator_name: the decorator name (a leading ``@`` is accepted)
    :return: the regenerated source without the task decorators
    :raises SyntaxError: if ``python_source`` cannot be parsed
    """
    tree = ast.parse(python_source)
    tree = _TaskDecoratorRemover(task_decorator_name).visit(tree)
    return ast.unparse(tree)