[Jobs] Move task retry logic to correct branch in stream_logs_by_id #4407

Open · wants to merge 14 commits into master
3 changes: 3 additions & 0 deletions sky/jobs/controller.py
@@ -473,6 +473,7 @@ def start(job_id, dag_yaml, retry_until_up):
"""Start the controller."""
controller_process = None
cancelling = False
task_id = None
try:
_handle_signal(job_id)
# TODO(suquark): In theory, we should make controller process a
@@ -491,6 +492,7 @@ def start(job_id, dag_yaml, retry_until_up):
except exceptions.ManagedJobUserCancelledError:
dag, _ = _get_dag_and_name(dag_yaml)
task_id, _ = managed_job_state.get_latest_task_id_status(job_id)
assert task_id is not None, job_id
logger.info(
f'Cancelling managed job, job_id: {job_id}, task_id: {task_id}')
managed_job_state.set_cancelling(
@@ -522,6 +524,7 @@ def start(job_id, dag_yaml, retry_until_up):
logger.info(f'Cluster of managed job {job_id} has been cleaned up.')

if cancelling:
assert task_id is not None, job_id # Since it's set with cancelling
managed_job_state.set_cancelled(
job_id=job_id,
callback_func=managed_job_utils.event_callback_func(
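The controller change above follows a standard pattern: pre-initialize task_id to None before the try block, set it on the cancellation path, and assert on it where cancelling is handled, so the later code can never hit an unbound local. A minimal, self-contained sketch of that flow (lookup_latest_task_id and the exception type are illustrative stand-ins, not the actual SkyPilot helpers):

from typing import Optional


def lookup_latest_task_id(job_id: int) -> Optional[int]:
    # Illustrative stand-in for managed_job_state.get_latest_task_id_status().
    return 0


def start(job_id: int) -> None:
    task_id: Optional[int] = None  # pre-initialize so it always exists
    cancelling = False
    try:
        raise KeyboardInterrupt  # stand-in for ManagedJobUserCancelledError
    except KeyboardInterrupt:
        task_id = lookup_latest_task_id(job_id)
        assert task_id is not None, job_id
        cancelling = True
    if cancelling:
        # Safe: task_id is assigned on the only path that sets cancelling=True,
        # which is exactly what the assert in the diff documents.
        assert task_id is not None, job_id
        print(f'Cancelled managed job {job_id}, task {task_id}')


start(1)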
10 changes: 6 additions & 4 deletions sky/jobs/state.py
@@ -575,10 +575,12 @@ def get_latest_task_id_status(
id_statuses = _get_all_task_ids_statuses(job_id)
if len(id_statuses) == 0:
return None, None
task_id, status = id_statuses[-1]
for task_id, status in id_statuses:
if not status.is_terminal():
break
task_id, status = next(
((tid, st) for tid, st in id_statuses if not st.is_terminal()),
id_statuses[-1],
)
# Unpack the tuple first, or it triggers a Pylint bug in recognizing
# the return type.
return task_id, status


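For readers less familiar with the idiom, the next() call above returns the first (task_id, status) pair whose status is not terminal, and falls back to the last recorded pair when every task has finished, which matches what the replaced for-loop did. A small self-contained sketch of the same pattern, using plain strings in place of the real ManagedJobStatus values:

from typing import List, Tuple

TERMINAL = {'SUCCEEDED', 'FAILED', 'CANCELLED'}  # illustrative terminal states


def latest_task(id_statuses: List[Tuple[int, str]]) -> Tuple[int, str]:
    # First non-terminal task if there is one, otherwise the last task.
    return next(
        ((tid, st) for tid, st in id_statuses if st not in TERMINAL),
        id_statuses[-1],
    )


print(latest_task([(0, 'SUCCEEDED'), (1, 'RUNNING'), (2, 'PENDING')]))  # (1, 'RUNNING')
print(latest_task([(0, 'SUCCEEDED'), (1, 'FAILED')]))                   # (1, 'FAILED')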
58 changes: 36 additions & 22 deletions sky/jobs/utils.py
@@ -384,8 +384,42 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
job_statuses = backend.get_job_status(handle, stream_logs=False)
job_status = list(job_statuses.values())[0]
assert job_status is not None, 'No job found.'
assert task_id is not None, job_id

user_code_failure_states = [
job_lib.JobStatus.FAILED, job_lib.JobStatus.FAILED_SETUP
]
if job_status in user_code_failure_states:
task_specs = managed_job_state.get_task_specs(
job_id, task_id)
if task_specs.get('max_restarts_on_errors', 0) == 0:
# We don't need to wait for the managed job status
# update, as the job is guaranteed to be in terminal
# state afterwards.
break
print()
status_display.update(
ux_utils.spinner_message(
'Waiting for next restart for the failed task'))
status_display.start()

def is_managed_job_status_updated():
"""Check if local managed job status reflects remote
job failure.

Ensures synchronization between remote cluster failure
detection (JobStatus.FAILED) and controller retry logic.
"""
nonlocal managed_job_status
_, managed_job_status = (
managed_job_state.get_latest_task_id_status(job_id))
return (managed_job_status !=
managed_job_state.ManagedJobStatus.RUNNING)

while not is_managed_job_status_updated():
time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)
continue
if job_status != job_lib.JobStatus.CANCELLED:
assert task_id is not None, job_id
if task_id < num_tasks - 1 and follow:
# The log for the current job is finished. We need to
# wait until next job to be started.
@@ -410,27 +444,7 @@ def stream_logs_by_id(job_id: int, follow: bool = True) -> str:
time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)
continue
else:
task_specs = managed_job_state.get_task_specs(
job_id, task_id)
if task_specs.get('max_restarts_on_errors', 0) == 0:
# We don't need to wait for the managed job status
# update, as the job is guaranteed to be in terminal
# state afterwards.
break
print()
status_display.update(
ux_utils.spinner_message(
'Waiting for next restart for the failed task'))
status_display.start()
while True:
_, managed_job_status = (
managed_job_state.get_latest_task_id_status(
job_id))
if (managed_job_status !=
managed_job_state.ManagedJobStatus.RUNNING):
break
time.sleep(JOB_STATUS_CHECK_GAP_SECONDS)
continue
break
# The job can be cancelled by the user or the controller (when
# the cluster is partially preempted).
logger.debug(
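The core of the logic moved into the failure branch is a small poll-until-updated loop: once the remote cluster reports FAILED or FAILED_SETUP and max_restarts_on_errors is non-zero, the log streamer waits until the controller's own state stops saying RUNNING before deciding whether to keep following logs. A self-contained sketch of that polling pattern (fetch_status and CHECK_GAP_SECONDS are illustrative stand-ins, not the actual SkyPilot API):

import time

CHECK_GAP_SECONDS = 1.0
_polls = 0


def fetch_status() -> str:
    # Illustrative stand-in for managed_job_state.get_latest_task_id_status();
    # pretends the controller moves the job out of RUNNING after a few polls.
    global _polls
    _polls += 1
    return 'RUNNING' if _polls < 3 else 'RECOVERING'


def wait_for_status_update() -> str:
    # Mirrors is_managed_job_status_updated() in the diff: keep polling until
    # the locally recorded managed-job status is no longer RUNNING.
    status = fetch_status()
    while status == 'RUNNING':
        time.sleep(CHECK_GAP_SECONDS)
        status = fetch_status()
    return status


print(wait_for_status_update())  # prints RECOVERING once the stub state flips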