diff --git a/src/latch_cli/centromere/ast_parsing.py b/src/latch_cli/centromere/ast_parsing.py
new file mode 100644
index 00000000..b0b128e4
--- /dev/null
+++ b/src/latch_cli/centromere/ast_parsing.py
@@ -0,0 +1,126 @@
+import ast
+from dataclasses import dataclass
+from pathlib import Path
+from textwrap import dedent
+from typing import Literal, Optional
+
+import click
+
+
+@dataclass
+class FlyteObject:
+    type: Literal["task", "workflow"]
+    name: str
+    dockerfile: Optional[Path] = None
+
+
+def is_task_decorator(decorator_name: str) -> bool:
+    return decorator_name in {
+        # og
+        "small_task",
+        "medium_task",
+        "large_task",
+        # og gpu
+        "small_gpu_task",
+        "large_gpu_task",
+        # custom
+        "custom_task",
+        "custom_memory_optimized_task",
+        # nf
+        "nextflow_runtime_task",
+        # l40s gpu
+        "g6e_xlarge_task",
+        "g6e_2xlarge_task",
+        "g6e_4xlarge_task",
+        "g6e_8xlarge_task",
+        "g6e_12xlarge_task",
+        "g6e_16xlarge_task",
+        "g6e_24xlarge_task",
+        # v100 gpu
+        "v100_x1_task",
+        "v100_x4_task",
+        "v100_x8_task",
+    }
+
+
+class Visitor(ast.NodeVisitor):
+    def __init__(self, file: Path):
+        self.file = file
+        self.flyte_objects: list[FlyteObject] = []
+
+    def visit_FunctionDef(self, node: ast.FunctionDef):
+        if len(node.decorator_list) == 0:
+            return self.generic_visit(node)
+
+        for decorator in node.decorator_list:
+            if isinstance(decorator, ast.Name):
+                if decorator.id == "workflow":
+                    self.flyte_objects.append(FlyteObject("workflow", node.name))
+                elif is_task_decorator(decorator.id):
+                    self.flyte_objects.append(FlyteObject("task", node.name))
+
+            elif isinstance(decorator, ast.Call):
+                func = decorator.func
+                assert isinstance(func, ast.Name)
+
+                if not is_task_decorator(func.id) and func.id != "workflow":
+                    continue
+
+                if func.id == "workflow":
+                    self.flyte_objects.append(FlyteObject("workflow", node.name))
+                    continue
+
+                dockerfile: Optional[Path] = None
+                for kw in decorator.keywords:
+                    if kw.arg != "dockerfile":
+                        continue
+
+                    try:
+                        dockerfile = Path(ast.literal_eval(kw.value)).resolve()
+                    except ValueError as e:
+                        click.secho(
+                            dedent(f"""\
+                            There was an issue parsing the `dockerfile` argument for task `{node.name}` in {self.file}.
+                            Note that values passed to `dockerfile` must be string literals.
+                            """),
+                            fg="red",
+                        )
+
+                        raise click.exceptions.Exit(1) from e
+
+                    if not dockerfile.exists():
+                        click.secho(
+                            f"The `dockerfile` value ({dockerfile}) for task `{node.name}` in {self.file} does not exist.",
+                            fg="red",
+                        )
+
+                        raise click.exceptions.Exit(1)
+
+                self.flyte_objects.append(FlyteObject("task", node.name, dockerfile))
+
+        return self.generic_visit(node)
+
+
+def get_flyte_objects(file: Path) -> list[FlyteObject]:
+    res = []
+    if file.is_dir():
+        for child in file.iterdir():
+            res.extend(get_flyte_objects(child))
+
+        return res
+
+    assert file.is_file()
+    if file.suffix != ".py":
+        return res
+
+    v = Visitor(file.resolve())
+
+    try:
+        parsed = ast.parse(file.read_text(), filename=file)
+    except SyntaxError as e:
+        click.secho(f"There is a syntax error in {file}: {e}", fg="red")
+        raise click.exceptions.Exit(1) from e
+
+    v.visit(parsed)
+
+    return v.flyte_objects
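Below is a minimal usage sketch of the new AST-based discovery (not part of the patch). It assumes the `latch` package is installed; the `wf/` layout and the decorator choices are illustrative only. Because `get_flyte_objects` parses source without importing it, the decorators do not need to resolve to real objects:

```python
from pathlib import Path
from textwrap import dedent

from latch_cli.centromere.ast_parsing import get_flyte_objects

# Hypothetical package directory containing one task and one workflow.
pkg = Path("wf")
pkg.mkdir(exist_ok=True)
(pkg / "__init__.py").write_text(dedent("""\
    @small_task
    def align(sample: str) -> str: ...

    @workflow
    def assembly_wf(sample: str) -> str:
        return align(sample=sample)
    """))

# The file is only parsed, never imported, so neither flytekit nor user code is loaded.
for obj in get_flyte_objects(pkg):
    print(obj.type, obj.name, obj.dockerfile)
# task align None
# workflow assembly_wf None
```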
diff --git a/src/latch_cli/centromere/ctx.py b/src/latch_cli/centromere/ctx.py
index 3a2cdb64..f63b59dc 100644
--- a/src/latch_cli/centromere/ctx.py
+++ b/src/latch_cli/centromere/ctx.py
@@ -11,17 +11,14 @@
 import paramiko
 import paramiko.util
 from docker.transport import SSHHTTPAdapter
-from flytekit.core.base_task import PythonTask
-from flytekit.core.context_manager import FlyteEntities
-from flytekit.core.workflow import PythonFunctionWorkflow
 
 import latch_cli.tinyrequests as tinyrequests
 from latch.utils import account_id_from_token, current_workspace, retrieve_or_login
+from latch_cli.centromere.ast_parsing import get_flyte_objects
 from latch_cli.centromere.utils import (
     RemoteConnInfo,
     _construct_dkr_client,
     _construct_ssh_client,
-    _import_flyte_objects,
 )
 from latch_cli.constants import docker_image_name_illegal_pat, latch_constants
 from latch_cli.docker_utils import get_default_dockerfile
@@ -176,8 +173,8 @@ def __init__(
 
             if self.workflow_type == WorkflowType.latchbiosdk:
                 try:
-                    _import_flyte_objects([self.pkg_root], module_name=self.wf_module)
-                except ModuleNotFoundError:
+                    flyte_objects = get_flyte_objects(self.pkg_root / self.wf_module)
+                except ModuleNotFoundError as e:
                     click.secho(
                         dedent(
                             f"""
@@ -189,14 +186,23 @@ def __init__(
                         ),
                         fg="red",
                     )
-                    raise click.exceptions.Exit(1)
+                    raise click.exceptions.Exit(1) from e
+
+                wf_name: Optional[str] = None
+
+                name_path = pkg_root / latch_constants.pkg_workflow_name
+                if name_path.exists():
+                    wf_name = name_path.read_text().strip()
+
+                if wf_name is None:
+                    for obj in flyte_objects:
+                        if obj.type != "workflow":
+                            continue
 
-                for entity in FlyteEntities.entities:
-                    if isinstance(entity, PythonFunctionWorkflow):
-                        self.workflow_name = entity.name
+                        wf_name = obj.name
                         break
 
-                if not hasattr(self, "workflow_name"):
+                if wf_name is None:
                     click.secho(
                         dedent("""\
                         Unable to locate workflow code. If you are a registering a Snakemake project, make sure to pass the Snakefile path with the --snakefile flag.
@@ -205,21 +211,17 @@ def __init__(
                     )
                     raise click.exceptions.Exit(1)
 
-                name_path = pkg_root / latch_constants.pkg_workflow_name
-                if name_path.exists():
-                    self.workflow_name = name_path.read_text().strip()
+                self.workflow_name = wf_name
 
-                for entity in FlyteEntities.entities:
-                    if isinstance(entity, PythonTask):
-                        if (
-                            hasattr(entity, "dockerfile_path")
-                            and entity.dockerfile_path is not None
-                        ):
-                            self.container_map[entity.name] = _Container(
-                                dockerfile=entity.dockerfile_path,
-                                image_name=self.task_image_name(entity.name),
-                                pkg_dir=entity.dockerfile_path.parent,
-                            )
+                for obj in flyte_objects:
+                    if obj.type != "task" or obj.dockerfile is None:
+                        continue
+
+                    self.container_map[obj.name] = _Container(
+                        dockerfile=obj.dockerfile,
+                        image_name=self.task_image_name(obj.name),
+                        pkg_dir=obj.dockerfile.parent,
+                    )
 
             elif self.workflow_type == WorkflowType.snakemake:
                 assert snakefile is not None
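For readability, here is the workflow-name resolution order that the ctx.py hunks implement, extracted from `_CentromereCtx.__init__` as a standalone sketch. The helper name and signature are hypothetical; `name_file` stands in for `latch_constants.pkg_workflow_name`:

```python
from pathlib import Path
from typing import Optional

from latch_cli.centromere.ast_parsing import FlyteObject


def resolve_workflow_name(
    pkg_root: Path, flyte_objects: list[FlyteObject], name_file: str
) -> Optional[str]:
    # 1. An explicit name file in the package root always wins.
    name_path = pkg_root / name_file
    if name_path.exists():
        return name_path.read_text().strip()

    # 2. Otherwise fall back to the first @workflow found by AST parsing.
    for obj in flyte_objects:
        if obj.type == "workflow":
            return obj.name

    # 3. None: registration aborts with the "Unable to locate workflow code" error.
    return None
```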
diff --git a/src/latch_cli/centromere/utils.py b/src/latch_cli/centromere/utils.py
index 5fb3e120..9514f757 100644
--- a/src/latch_cli/centromere/utils.py
+++ b/src/latch_cli/centromere/utils.py
@@ -1,6 +1,5 @@
 import builtins
 import contextlib
-import functools
 import os
 import random
 import string
@@ -12,10 +11,8 @@
 from typing import Callable, Iterator, List, Optional, TypeVar
 
 import docker
+import docker.errors
 import paramiko
-from flytekit.core.context_manager import FlyteContext, FlyteContextManager
-from flytekit.core.data_persistence import FileAccessProvider
-from flytekit.tools import module_loader
 from typing_extensions import ParamSpec
 
 from latch_cli.constants import latch_constants
@@ -42,6 +39,10 @@ def _add_sys_paths(paths: List[Path]) -> Iterator[None]:
 
 
 def _import_flyte_objects(paths: List[Path], module_name: str = "wf"):
+    from flytekit.core.context_manager import FlyteContext, FlyteContextManager
+    from flytekit.core.data_persistence import FileAccessProvider
+    from flytekit.tools import module_loader
+
     with _add_sys_paths(paths):
 
         class FakeModule(ModuleType):
@@ -76,20 +77,15 @@ def __new__(*args, **kwargs):
         def fake_import(name, globals=None, locals=None, fromlist=(), level=0):
             try:
                 return real_import(
-                    name,
-                    globals=globals,
-                    locals=locals,
-                    fromlist=fromlist,
-                    level=level,
+                    name, globals=globals, locals=locals, fromlist=fromlist, level=level
                 )
-            except (ModuleNotFoundError, AttributeError) as e:
+            except (ModuleNotFoundError, AttributeError):
                 return FakeModule(name)
 
         # Temporary ctx tells lytekit to skip local execution when
         # inspecting objects
         fap = FileAccessProvider(
-            local_sandbox_dir=tempfile.mkdtemp(prefix="foo"),
-            raw_output_prefix="bar",
+            local_sandbox_dir=tempfile.mkdtemp(prefix="foo"), raw_output_prefix="bar"
         )
 
         tmp_context = FlyteContext(fap, inspect_objects_only=True)
@@ -201,9 +197,7 @@ def _construct_ssh_client(
             raise ConnectionError("unable to create connection to jump host")
 
         sock = gateway_transport.open_channel(
-            kind="direct-tcpip",
-            dest_addr=(remote_conn_info.ip, 22),
-            src_addr=("", 0),
+            kind="direct-tcpip", dest_addr=(remote_conn_info.ip, 22), src_addr=("", 0)
        )
     else:
         sock = None
@@ -214,10 +208,7 @@ def _construct_ssh_client(
     ssh.load_system_host_keys()
     ssh.set_missing_host_key_policy(paramiko.MissingHostKeyPolicy)
     ssh.connect(
-        remote_conn_info.ip,
-        username=remote_conn_info.username,
-        sock=sock,
-        pkey=pkey,
+        remote_conn_info.ip, username=remote_conn_info.username, sock=sock, pkey=pkey
     )
 
     transport = ssh.get_transport()
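The utils.py hunks defer the flytekit imports into `_import_flyte_objects` itself, so importing `latch_cli.centromere.utils` (which the new AST-only path in ctx.py still does for `RemoteConnInfo` and the client constructors) no longer triggers a flytekit import. A rough way to see what that saves, assuming flytekit is installed; the module name is the only assumption here:

```python
import importlib
import time


def time_cold_import(module_name: str) -> float:
    """Measure how long the first import of `module_name` takes in this process."""
    start = time.monotonic()
    importlib.import_module(module_name)
    return time.monotonic() - start


# With the imports moved inside _import_flyte_objects, this cost is paid only
# when that function is called, not on every `import latch_cli.centromere.utils`.
print(f"flytekit cold import: {time_cold_import('flytekit'):.2f}s")
```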
diff --git a/uv.lock b/uv.lock
index 4ea2284c..b8291b65 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1005,7 +1005,7 @@ wheels = [
 
 [[package]]
 name = "latch"
-version = "2.54.10"
+version = "2.55.3"
 source = { editable = "." }
 dependencies = [
     { name = "aioconsole" },