diff --git a/rapyuta_io/clients/device.py b/rapyuta_io/clients/device.py index 5fb026dc..e6f2fcb1 100644 --- a/rapyuta_io/clients/device.py +++ b/rapyuta_io/clients/device.py @@ -4,6 +4,7 @@ import subprocess from six.moves.urllib.parse import urlencode +from time import sleep import enum import requests @@ -567,7 +568,44 @@ def execute_command(self, command, retry_limit=0): if response.status_code == requests.codes.BAD_REQUEST: raise ParameterMissingException(get_error(response.text)) execution_result = get_api_response_data(response) - return execution_result[self.uuid] + if not command.bg: + return execution_result[self.uuid] + jid = execution_result.get('jid') + if not jid: + raise ValueError("Job ID not found in the response") + return self.fetch_command_result(jid, [self.uuid], timeout=command.timeout) + + def fetch_command_result(self, jid: str, deviceids: list, timeout: int): + """ + Fetch the result of the command execution using the job ID (jid) and the first device ID from the list. + Args: + jobid (str): The job ID of the executed command. + deviceids (list): A list of device IDs on which the command was executed. + timeout (int): The maximum time to wait for the result (in seconds). Default is 300 seconds. + Returns: + dict: The result of the command execution. + Raises: + TimeoutError: If the result is not available within the timeout period. + APIError: If the API returns an error. + """ + + if not deviceids or not isinstance(deviceids, list): + raise ValueError("Device IDs must be provided as a non-empty list.") + url = self._device_api_host + DEVICE_COMMAND_API_PATH + "jobid" + payload = { + "jid": jid, + "device_id": deviceids[0] + } + total_time_waited = 0 + wait_interval = 10 + while total_time_waited < timeout: + response = self._execute_api(url, HttpMethod.POST, payload) + if response.status_code == requests.codes.OK: + result = get_api_response_data(response) + return result[deviceids[0]] + sleep(wait_interval) + total_time_waited += wait_interval + raise TimeoutError(f"Command result not available after {timeout} seconds") def get_config_variables(self): """ diff --git a/rapyuta_io/clients/model.py b/rapyuta_io/clients/model.py index 14e6f286..1af941fc 100644 --- a/rapyuta_io/clients/model.py +++ b/rapyuta_io/clients/model.py @@ -61,7 +61,7 @@ class Command(ObjDict): """ - def __init__(self, cmd, shell=None, env=None, bg=False, runas=None, pwd=None, cwd=None): + def __init__(self, cmd, shell=None, env=None, bg=False, runas=None, pwd=None, cwd=None, timeout=300): super(ObjDict, self).__init__() if env is None: env = dict() @@ -73,6 +73,7 @@ def __init__(self, cmd, shell=None, env=None, bg=False, runas=None, pwd=None, cw self.cwd = pwd if cwd is not None: self.cwd = cwd + self.timeout = timeout self.validate() def validate(self): @@ -93,6 +94,8 @@ def validate(self): raise InvalidCommandException('Invalid environment variables') return raise InvalidCommandException('Invalid environment variables') + if self.timeout <= 0: + raise InvalidCommandException("Invalid timeout value") def to_json(self): # TODO: we need to rewrite this function.