
feat(BA-1045): Add Action Tests for Session #4265


Draft · wants to merge 10 commits into base: main
1 change: 1 addition & 0 deletions changes/4265.feature.md
@@ -0,0 +1 @@
Add Action Tests for `Session`.
7 changes: 4 additions & 3 deletions src/ai/backend/manager/api/session.py
@@ -771,7 +771,8 @@ async def get_commit_status(request: web.Request, params: Mapping[str, Any]) ->
owner_access_key=owner_access_key,
)
)
return web.json_response(result.result, status=HTTPStatus.OK)
resp = result.commit_info.asdict()
return web.json_response(resp, status=HTTPStatus.OK)


@server_status_required(ALL_ALLOWED)
@@ -796,7 +797,7 @@ async def get_abusing_report(request: web.Request, params: Mapping[str, Any]) ->
owner_access_key=owner_access_key,
)
)
return web.json_response(result.result or {}, status=HTTPStatus.OK)
return web.json_response(result.abuse_report or {}, status=HTTPStatus.OK)


@server_status_required(ALL_ALLOWED)
@@ -1156,7 +1157,7 @@ async def get_info(request: web.Request) -> web.Response:
except BackendError:
log.exception("GET_INFO: exception")
raise
return web.json_response(result.result, status=HTTPStatus.OK)
return web.json_response(result.session_info.asdict(), status=HTTPStatus.OK)


@server_status_required(READ_ALLOWED)
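The three handler changes above share one pattern: instead of returning a loosely typed `result.result` mapping, each action result now exposes a dedicated field (`commit_info`, `abuse_report`, `session_info`), and dataclass payloads are serialized with `asdict()` before being passed to `web.json_response()`. A minimal sketch of that pattern follows; `CommitStatusData` and `GetCommitStatusResult` are assumed names, since the actual result classes are not shown in this diff.

```python
# Minimal sketch of the typed-result pattern used by the handlers above.
# CommitStatusData / GetCommitStatusResult are assumed names, not the PR's classes.
import uuid
from dataclasses import dataclass


@dataclass
class CommitStatusData:
    session_id: uuid.UUID
    status: str

    def asdict(self) -> dict:
        # Plain-dict form handed to web.json_response(..., status=HTTPStatus.OK).
        return {"session_id": str(self.session_id), "status": self.status}


@dataclass
class GetCommitStatusResult:
    commit_info: CommitStatusData
```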
1 change: 1 addition & 0 deletions src/ai/backend/manager/data/kernel/BUILD
@@ -0,0 +1 @@
python_sources(name="src")
101 changes: 101 additions & 0 deletions src/ai/backend/manager/data/kernel/types.py
@@ -0,0 +1,101 @@
import uuid
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, Optional

from ai.backend.common.types import (
ClusterMode,
ResourceSlot,
SessionResult,
SessionTypes,
VFolderMount,
)

if TYPE_CHECKING:
from ai.backend.manager.models.kernel import KernelStatus


@dataclass
class KernelData:
# --- identity & session ---
id: uuid.UUID
session_id: uuid.UUID
session_creation_id: Optional[str]
session_name: Optional[str]
session_type: SessionTypes

# --- cluster info ---
cluster_mode: ClusterMode
cluster_size: int
cluster_role: str
cluster_idx: int
local_rank: int
cluster_hostname: str

# --- uid / gid ---
uid: Optional[int]
main_gid: Optional[int]
gids: Optional[list[int]]

# --- ownership / auth ---
scaling_group: Optional[str]
agent: Optional[str]
agent_addr: Optional[str]
domain_name: str
group_id: uuid.UUID
user_uuid: uuid.UUID
access_key: Optional[str]

# --- image & registry ---
image: Optional[str]
architecture: str
registry: Optional[str]
tag: Optional[str]
container_id: Optional[str]

# --- resources ---
occupied_slots: ResourceSlot
requested_slots: ResourceSlot
occupied_shares: dict
environ: Optional[list[str]]
mounts: Optional[list[str]]
mount_map: dict
vfolder_mounts: Optional[list[VFolderMount]]
attached_devices: dict
resource_opts: dict
bootstrap_script: Optional[str]

# --- networking ---
kernel_host: Optional[str]
repl_in_port: int
repl_out_port: int
stdin_port: int
stdout_port: int
service_ports: Optional[dict]
preopen_ports: Optional[list[int]]
use_host_network: bool

# --- lifecycle timestamps ---
created_at: datetime
terminated_at: Optional[datetime]
starts_at: Optional[datetime]

# --- runtime status ---
status: "KernelStatus"
status_changed: Optional[datetime]
status_info: Optional[str]
status_data: Optional[dict]
status_history: Optional[dict]

# --- callbacks & commands ---
callback_url: Optional[str]
startup_command: Optional[str]

# --- result & logs ---
result: SessionResult
internal_data: Optional[dict]
container_log: Optional[bytes]

# --- metrics ---
num_queries: int
last_stat: Optional[dict]
8 changes: 6 additions & 2 deletions src/ai/backend/manager/data/session/types.py
@@ -5,6 +5,7 @@

from ai.backend.common.data.vfolder.types import VFolderMountData
from ai.backend.common.types import (
AccessKey,
ClusterMode,
SessionResult,
SessionTypes,
@@ -32,9 +33,10 @@ class SessionData:
created_at: datetime
status: "SessionStatus"
result: SessionResult
num_queries: int
creation_id: Optional[str]
name: Optional[str]
access_key: Optional[str]
access_key: Optional[AccessKey]
agent_ids: Optional[list[str]]
images: Optional[list[str]]
tag: Optional[str]
@@ -52,7 +54,9 @@ class SessionData:
status_history: Optional[dict[str, Any]]
callback_url: Optional[str]
startup_command: Optional[str]
num_queries: Optional[int]
last_stat: Optional[dict[str, Any]]
network_type: Optional[NetworkType]
network_id: Optional[str]

# Loaded from relationship
service_ports: Optional[str]
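Two small corrections in `SessionData`: `num_queries` moves into the non-optional block (the duplicate `Optional[int]` entry further down is dropped), and `access_key` is annotated with the `AccessKey` alias from `ai.backend.common.types` instead of a bare `str`. Assuming `AccessKey` is a `typing.NewType` over `str`, as it is used elsewhere in the codebase, the change only tightens static checks:

```python
# Sketch, assuming AccessKey = typing.NewType("AccessKey", str) as in
# ai.backend.common.types; the runtime value stays a plain string.
from ai.backend.common.types import AccessKey

ak = AccessKey("AKIA-EXAMPLE")  # hypothetical key value
assert isinstance(ak, str)      # NewType adds no runtime wrapper
```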
21 changes: 18 additions & 3 deletions src/ai/backend/manager/models/base.py
@@ -37,6 +37,7 @@
import yarl
from aiodataloader import DataLoader
from aiotools import apartial
from dateutil.parser import isoparse
from graphene.types import Scalar
from graphene.types.scalars import MAX_INT, MIN_INT
from graphql import Undefined
@@ -1355,6 +1356,13 @@ async def populate_fixture(
async with engine.begin() as conn:
# Apply typedecorator manually for required columns
for col in table.columns:
if isinstance(col.type, sa.sql.sqltypes.DateTime):
for row in rows:
if col.name in row:
if row[col.name] is not None:
row[col.name] = isoparse(row[col.name])
else:
row[col.name] = None
if isinstance(col.type, EnumType):
for row in rows:
if col.name in row:
@@ -1363,12 +1371,19 @@
for row in rows:
if col.name in row:
row[col.name] = col.type._enum_cls(row[col.name])
elif isinstance(
col.type, (StructuredJSONObjectColumn, StructuredJSONObjectListColumn)
):
elif isinstance(col.type, (StructuredJSONObjectColumn)):
for row in rows:
if col.name in row:
row[col.name] = col.type._schema.from_json(row[col.name])
elif isinstance(col.type, (StructuredJSONObjectListColumn)):
for row in rows:
if col.name in row and row[col.name] is not None:
row[col.name] = [
item
if isinstance(item, col.type._schema)
else col.type._schema.from_json(item)
for item in row[col.name]
]

match op_mode:
case FixtureOpModes.INSERT:
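The new `populate_fixture()` branches normalize fixture rows before insertion: ISO-8601 strings in `DateTime` columns are parsed with `dateutil.parser.isoparse`, and items of `StructuredJSONObjectListColumn` columns are converted element-wise, skipping items that are already schema instances. A self-contained sketch of the datetime normalization; the column names are illustrative, only the `isoparse()` behavior comes from the PR:

```python
# Sketch of the DateTime normalization applied to fixture rows above.
from dateutil.parser import isoparse

row = {
    "created_at": "2024-05-01T12:00:00+00:00",  # ISO string from a JSON fixture
    "terminated_at": None,                       # kept as None
}

for name in ("created_at", "terminated_at"):
    value = row[name]
    row[name] = isoparse(value) if value is not None else None

print(row["created_at"])  # -> 2024-05-01 12:00:00+00:00, a timezone-aware datetime
```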
62 changes: 61 additions & 1 deletion src/ai/backend/manager/models/gql_models/session.py
@@ -31,6 +31,7 @@
SessionResult,
VFolderMount,
)
from ai.backend.manager.data.session.types import SessionData
from ai.backend.manager.defs import DEFAULT_ROLE
from ai.backend.manager.idle import ReportInfo
from ai.backend.manager.models.kernel import KernelRow
@@ -343,6 +344,65 @@ def from_row(
result.permissions = [] if permissions is None else permissions
return result

@classmethod
def from_dataclass(
cls,
ctx: GraphQueryContext,
session_data: SessionData,
*,
permissions: Optional[Iterable[ComputeSessionPermission]] = None,
) -> Self:
status_history = session_data.status_history or {}
raw_scheduled_at = status_history.get(SessionStatus.SCHEDULED.name)
if not session_data.vfolder_mounts:
vfolder_mounts = []
else:
vfolder_mounts = [vf.vfid.folder_id for vf in session_data.vfolder_mounts]

result = cls(
# identity
id=session_data.id, # auto-converted to Relay global ID
row_id=session_data.id,
tag=session_data.tag,
name=session_data.name,
type=session_data.session_type,
cluster_template=None,
cluster_mode=session_data.cluster_mode,
cluster_size=session_data.cluster_size,
priority=session_data.priority,
# ownership
domain_name=session_data.domain_name,
project_id=session_data.group_id,
user_id=session_data.user_uuid,
access_key=session_data.access_key,
# status
status=session_data.status.name,
# status_changed=row.status_changed, # FIXME: generated attribute
status_info=session_data.status_info,
status_data=session_data.status_data,
status_history=status_history,
created_at=session_data.created_at,
starts_at=session_data.starts_at,
terminated_at=session_data.terminated_at,
scheduled_at=datetime.fromisoformat(raw_scheduled_at)
if raw_scheduled_at is not None
else None,
startup_command=session_data.startup_command,
result=session_data.result.name,
# resources
agent_ids=session_data.agent_ids,
scaling_group=session_data.scaling_group_name,
vfolder_mounts=vfolder_mounts,
occupied_slots=session_data.occupying_slots,
requested_slots=session_data.requested_slots,
image_references=session_data.images,
service_ports=session_data.service_ports,
# statistics
num_queries=session_data.num_queries,
)
result.permissions = [] if permissions is None else permissions
return result

async def resolve_idle_checks(self, info: graphene.ResolveInfo) -> dict[str, Any] | None:
graph_ctx: GraphQueryContext = info.context
loader = graph_ctx.dataloader_manager.get_loader_by_func(
@@ -708,7 +768,7 @@ async def mutate_and_get_payload(
)

return ModifyComputeSession(
ComputeSessionNode.from_row(graph_ctx, result.session_row),
ComputeSessionNode.from_dataclass(graph_ctx, result.session_data),
input.get("client_mutation_id"),
)

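With `from_dataclass()` in place, `ModifyComputeSession` builds its payload from the plain `SessionData` returned by the action instead of an ORM row. Inside `from_dataclass()`, `scheduled_at` is reconstructed from `status_history`, which maps status names to ISO timestamps; only the `SCHEDULED` entry is parsed back into a `datetime`. A self-contained sketch of just that step, with an illustrative `status_history` payload:

```python
# Sketch of the scheduled_at derivation in from_dataclass() above.
from datetime import datetime

status_history = {
    "PENDING": "2024-05-01T11:59:00+00:00",
    "SCHEDULED": "2024-05-01T12:00:00+00:00",
}
raw_scheduled_at = status_history.get("SCHEDULED")
scheduled_at = (
    datetime.fromisoformat(raw_scheduled_at) if raw_scheduled_at is not None else None
)
print(scheduled_at)  # -> 2024-05-01 12:00:00+00:00
```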
68 changes: 68 additions & 0 deletions src/ai/backend/manager/models/kernel.py
@@ -39,6 +39,7 @@
VFolderMount,
)
from ai.backend.logging import BraceStyleAdapter
from ai.backend.manager.data.kernel.types import KernelData

from ..api.exceptions import (
BackendError,
@@ -542,6 +543,73 @@ class KernelRow(Base):
group_row = relationship("GroupRow", back_populates="kernels")
user_row = relationship("UserRow", back_populates="kernels")

@classmethod
def from_dataclass(cls, data: KernelData) -> KernelRow:
raise NotImplementedError("KernelRow.from_dataclass() is not implemented.")

def to_dataclass(self) -> KernelData:
return KernelData(
id=self.id,
session_id=self.session_id,
session_creation_id=self.session_creation_id,
session_name=self.session_name,
session_type=self.session_type,
cluster_mode=self.cluster_mode,
cluster_size=self.cluster_size,
cluster_role=self.cluster_role,
cluster_idx=self.cluster_idx,
local_rank=self.local_rank,
cluster_hostname=self.cluster_hostname,
uid=self.uid,
main_gid=self.main_gid,
gids=self.gids,
scaling_group=self.scaling_group,
agent=self.agent_row.to_dataclass() if getattr(self, "agent_row", None) else None,
agent_addr=self.agent_addr,
domain_name=self.domain_name,
group_id=self.group_id,
user_uuid=self.user_uuid,
access_key=self.access_key,
image=self.image,
architecture=self.architecture,
registry=self.registry,
tag=self.tag,
container_id=self.container_id,
occupied_slots=self.occupied_slots,
requested_slots=self.requested_slots,
occupied_shares=self.occupied_shares,
environ=self.environ,
mounts=self.mounts,
mount_map=self.mount_map,
vfolder_mounts=self.vfolder_mounts,
attached_devices=self.attached_devices,
resource_opts=self.resource_opts,
bootstrap_script=self.bootstrap_script,
kernel_host=self.kernel_host or self.agent_addr,
repl_in_port=self.repl_in_port,
repl_out_port=self.repl_out_port,
stdin_port=self.stdin_port,
stdout_port=self.stdout_port,
service_ports=self.service_ports,
preopen_ports=self.preopen_ports,
use_host_network=self.use_host_network,
created_at=self.created_at,
terminated_at=self.terminated_at,
starts_at=self.starts_at,
status=self.status,
status_changed=self.status_changed,
status_info=self.status_info,
status_data=self.status_data,
status_history=self.status_history,
callback_url=str(self.callback_url) if self.callback_url else None,
startup_command=self.startup_command,
result=self.result,
internal_data=self.internal_data,
container_log=self.container_log,
num_queries=self.num_queries,
last_stat=self.last_stat,
)

@property
def image_ref(self) -> ImageRef | None:
return self.image_row.image_ref if self.image_row else None
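`to_dataclass()` gives the new action tests a plain, ORM-free snapshot of a kernel row to assert against. A rough sketch of test-side usage, assuming a `kernel_row` supplied by a test fixture (the fixture itself is not part of this diff):

```python
# Sketch of test-side usage; `kernel_row` is assumed to come from a fixture.
kernel_data = kernel_row.to_dataclass()

assert kernel_data.id == kernel_row.id
assert kernel_data.session_id == kernel_row.session_id
# Per the conversion above, callback_url is stringified and kernel_host
# falls back to agent_addr when unset.
assert kernel_data.callback_url is None or isinstance(kernel_data.callback_url, str)
```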