Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace importing the wf with ast parsing #516

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions src/latch_cli/centromere/ast_parsing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import ast
from dataclasses import dataclass
from pathlib import Path
from textwrap import dedent
from typing import Literal, Optional

import click


@dataclass
class FlyteObject:
    """A task or workflow found by statically parsing a Python source file."""

    # Kind of entity the decorated function declares.
    type: Literal["task", "workflow"]
    # Name of the decorated function, exactly as written in the source.
    name: str
    # Dockerfile passed to the task decorator, when one was given.
    dockerfile: Optional[Path] = None


# Names of the decorators that mark a function as a task.
_TASK_DECORATOR_NAMES = frozenset((
    # og
    "small_task",
    "medium_task",
    "large_task",
    # og gpu
    "small_gpu_task",
    "large_gpu_task",
    # custom
    "custom_task",
    "custom_memory_optimized_task",
    # nf
    "nextflow_runtime_task",
    # l40s gpu
    "g6e_xlarge_task",
    "g6e_2xlarge_task",
    "g6e_4xlarge_task",
    "g6e_8xlarge_task",
    "g6e_12xlarge_task",
    "g6e_16xlarge_task",
    "g6e_24xlarge_task",
    # v100 gpu
    "v100_x1_task",
    "v100_x4_task",
    "v100_x8_task",
))


def is_task_decorator(decorator_name: str) -> bool:
    """Return True when ``decorator_name`` is one of the known task decorators."""
    return decorator_name in _TASK_DECORATOR_NAMES


class Visitor(ast.NodeVisitor):
    """Collect the tasks and workflows declared in one parsed source file.

    Every function definition whose decorator list contains ``workflow`` or a
    known task decorator (bare, or called with arguments) is recorded in
    ``flyte_objects``. Decorators are matched purely by their bare name; no
    code from the visited file is executed.
    """

    def __init__(self, file: Path):
        # Only used to make error messages point at the offending file.
        self.file = file
        self.flyte_objects: list[FlyteObject] = []

    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Record ``node`` if it is decorated as a task or workflow.

        Raises:
            click.exceptions.Exit: if a task's ``dockerfile`` argument is not
                a string literal, or points at a file that does not exist.
        """
        # NOTE(review): nested functions and class methods are recorded too,
        # because the visitor keeps descending; it may be preferable to only
        # consider top-level functions (unclear how flyte treats methods).
        for decorator in node.decorator_list:
            if isinstance(decorator, ast.Name):
                if decorator.id == "workflow":
                    self.flyte_objects.append(FlyteObject("workflow", node.name))
                elif is_task_decorator(decorator.id):
                    self.flyte_objects.append(FlyteObject("task", node.name))
                continue

            if not isinstance(decorator, ast.Call):
                continue

            func = decorator.func
            if not isinstance(func, ast.Name):
                # Attribute / subscript / nested-call decorators such as
                # `@functools.lru_cache(maxsize=None)` are not task or
                # workflow decorators as far as this visitor can tell, so
                # skip them rather than failing.
                continue

            if func.id == "workflow":
                # NOTE(review): a qualified name (including the module path,
                # e.g. `src.main.wf`) may be needed here instead of the bare
                # function name.
                self.flyte_objects.append(FlyteObject("workflow", node.name))
                continue

            if not is_task_decorator(func.id):
                continue

            dockerfile = self._dockerfile_from_call(node, decorator)
            self.flyte_objects.append(FlyteObject("task", node.name, dockerfile))

        return self.generic_visit(node)

    def _dockerfile_from_call(
        self, node: ast.FunctionDef, call: ast.Call
    ) -> Optional[Path]:
        """Return the resolved ``dockerfile=`` argument of a task decorator call."""
        dockerfile: Optional[Path] = None

        # NOTE(review): only keyword arguments are inspected; double-check
        # that `dockerfile` is keyword-only in the task decorators.
        for kw in call.keywords:
            if kw.arg != "dockerfile":
                continue

            try:
                value = ast.literal_eval(kw.value)
                if value is not None and not isinstance(value, str):
                    raise TypeError(
                        f"expected a string literal, got {type(value).__name__}"
                    )
            except (ValueError, TypeError, SyntaxError) as e:
                click.secho(
                    dedent(f"""\
                        There was an issue parsing the `dockerfile` argument for task `{node.name}` in {self.file}.
                        Note that values passed to `dockerfile` must be string literals.
                        """),
                    fg="red",
                )

                raise click.exceptions.Exit(1) from e

            if value is None:
                # An explicit `dockerfile=None` means "no custom dockerfile".
                continue

            # NOTE(review): a relative path is resolved against the current
            # working directory; it may need to be resolved relative to the
            # project root or to the file being parsed instead.
            dockerfile = Path(value).resolve()

            if not dockerfile.exists():
                click.secho(
                    f"The `dockerfile` value ({dockerfile}) for task `{node.name}` in {self.file} does not exist.",
                    fg="red",
                )

                raise click.exceptions.Exit(1)

        return dockerfile


def get_flyte_objects(file: Path) -> list[FlyteObject]:
    """Statically collect the tasks and workflows declared under ``file``.

    ``file`` may be a single Python file or a directory; directories are
    walked in depth-first, name-sorted order so the result order does not
    depend on the filesystem.

    Args:
        file: Python source file or directory to scan.

    Returns:
        Every ``FlyteObject`` found, in traversal order.

    Raises:
        ModuleNotFoundError: if ``file`` does not exist.
        click.exceptions.Exit: if a file cannot be parsed, or the visitor
            rejects a task's ``dockerfile`` argument.
    """
    if not file.exists():
        raise ModuleNotFoundError(f"No file or directory found at {file}")

    res: list[FlyteObject] = []

    # Explicit stack instead of recursion so a very deep tree cannot exhaust
    # the interpreter's recursion limit.
    pending: list[Path] = [file]
    while pending:
        cur = pending.pop()

        if cur.is_dir():
            # NOTE(review): directories without an `__init__.py` could be
            # skipped, since they cannot contain importable modules.
            # Reverse-sorted so that popping from the end visits children in
            # ascending name order.
            pending.extend(sorted(cur.iterdir(), reverse=True))
            continue

        # TODO: filtering on the `.py` suffix is only a heuristic; following
        # the import graph from the workflow module would be more accurate.
        if not cur.is_file() or cur.suffix != ".py":
            continue

        v = Visitor(cur.resolve())

        try:
            # Bytes let the parser honour the file's own encoding declaration
            # instead of whatever the locale default happens to be.
            parsed = ast.parse(cur.read_bytes(), filename=str(cur))
        except (SyntaxError, ValueError) as e:
            # NOTE(review): printing the full traceback here was suggested.
            click.secho(f"There is a syntax error in {cur}: {e}", fg="red")

            raise click.exceptions.Exit(1) from e

        v.visit(parsed)
        res.extend(v.flyte_objects)

    return res
52 changes: 27 additions & 25 deletions src/latch_cli/centromere/ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,14 @@
import paramiko
import paramiko.util
from docker.transport import SSHHTTPAdapter
from flytekit.core.base_task import PythonTask
from flytekit.core.context_manager import FlyteEntities
from flytekit.core.workflow import PythonFunctionWorkflow

import latch_cli.tinyrequests as tinyrequests
from latch.utils import account_id_from_token, current_workspace, retrieve_or_login
from latch_cli.centromere.ast_parsing import get_flyte_objects
from latch_cli.centromere.utils import (
RemoteConnInfo,
_construct_dkr_client,
_construct_ssh_client,
_import_flyte_objects,
)
from latch_cli.constants import docker_image_name_illegal_pat, latch_constants
from latch_cli.docker_utils import get_default_dockerfile
Expand Down Expand Up @@ -176,8 +173,8 @@ def __init__(

if self.workflow_type == WorkflowType.latchbiosdk:
try:
_import_flyte_objects([self.pkg_root], module_name=self.wf_module)
except ModuleNotFoundError:
flyte_objects = get_flyte_objects(self.pkg_root / self.wf_module)
except ModuleNotFoundError as e:
click.secho(
dedent(
f"""
Expand All @@ -189,14 +186,23 @@ def __init__(
),
fg="red",
)
raise click.exceptions.Exit(1)
raise click.exceptions.Exit(1) from e

wf_name: Optional[str] = None

name_path = pkg_root / latch_constants.pkg_workflow_name
if name_path.exists():
wf_name = name_path.read_text().strip()

if wf_name is None:
for obj in flyte_objects:
if obj.type != "workflow":
continue

for entity in FlyteEntities.entities:
if isinstance(entity, PythonFunctionWorkflow):
self.workflow_name = entity.name
wf_name = obj.name
break

if not hasattr(self, "workflow_name"):
if wf_name is None:
click.secho(
dedent("""\
Unable to locate workflow code. If you are a registering a Snakemake project, make sure to pass the Snakefile path with the --snakefile flag.
Expand All @@ -205,21 +211,17 @@ def __init__(
)
raise click.exceptions.Exit(1)

name_path = pkg_root / latch_constants.pkg_workflow_name
if name_path.exists():
self.workflow_name = name_path.read_text().strip()
self.workflow_name = wf_name

for entity in FlyteEntities.entities:
if isinstance(entity, PythonTask):
if (
hasattr(entity, "dockerfile_path")
and entity.dockerfile_path is not None
):
self.container_map[entity.name] = _Container(
dockerfile=entity.dockerfile_path,
image_name=self.task_image_name(entity.name),
pkg_dir=entity.dockerfile_path.parent,
)
for obj in flyte_objects:
if obj.type != "task" or obj.dockerfile is None:
continue

self.container_map[obj.name] = _Container(
dockerfile=obj.dockerfile,
image_name=self.task_image_name(obj.name),
pkg_dir=obj.dockerfile.parent,
)

elif self.workflow_type == WorkflowType.snakemake:
assert snakefile is not None
Expand Down
29 changes: 10 additions & 19 deletions src/latch_cli/centromere/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import builtins
import contextlib
import functools
import os
import random
import string
Expand All @@ -12,10 +11,8 @@
from typing import Callable, Iterator, List, Optional, TypeVar

import docker
import docker.errors
import paramiko
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.tools import module_loader
from typing_extensions import ParamSpec

from latch_cli.constants import latch_constants
Expand All @@ -42,6 +39,10 @@ def _add_sys_paths(paths: List[Path]) -> Iterator[None]:


def _import_flyte_objects(paths: List[Path], module_name: str = "wf"):
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.tools import module_loader

with _add_sys_paths(paths):

class FakeModule(ModuleType):
Expand Down Expand Up @@ -76,20 +77,15 @@ def __new__(*args, **kwargs):
def fake_import(name, globals=None, locals=None, fromlist=(), level=0):
try:
return real_import(
name,
globals=globals,
locals=locals,
fromlist=fromlist,
level=level,
name, globals=globals, locals=locals, fromlist=fromlist, level=level
)
except (ModuleNotFoundError, AttributeError) as e:
except (ModuleNotFoundError, AttributeError):
return FakeModule(name)

# Temporary ctx tells lytekit to skip local execution when
# inspecting objects
fap = FileAccessProvider(
local_sandbox_dir=tempfile.mkdtemp(prefix="foo"),
raw_output_prefix="bar",
local_sandbox_dir=tempfile.mkdtemp(prefix="foo"), raw_output_prefix="bar"
)
tmp_context = FlyteContext(fap, inspect_objects_only=True)

Expand Down Expand Up @@ -201,9 +197,7 @@ def _construct_ssh_client(
raise ConnectionError("unable to create connection to jump host")

sock = gateway_transport.open_channel(
kind="direct-tcpip",
dest_addr=(remote_conn_info.ip, 22),
src_addr=("", 0),
kind="direct-tcpip", dest_addr=(remote_conn_info.ip, 22), src_addr=("", 0)
)
else:
sock = None
Expand All @@ -214,10 +208,7 @@ def _construct_ssh_client(
ssh.load_system_host_keys()
ssh.set_missing_host_key_policy(paramiko.MissingHostKeyPolicy)
ssh.connect(
remote_conn_info.ip,
username=remote_conn_info.username,
sock=sock,
pkey=pkey,
remote_conn_info.ip, username=remote_conn_info.username, sock=sock, pkey=pkey
)

transport = ssh.get_transport()
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.