Skip to content

Commit

Permalink
custom run streaming fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
igorbenav committed Dec 12, 2024
1 parent ab0f524 commit 6a17ff7
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 18 deletions.
3 changes: 3 additions & 0 deletions clientai/agent/core/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,9 @@ def execute_step(
raise StepError("No agent context available for step execution")

logger.info(f"Executing step '{step.name}'")
if stream is None:
stream = getattr(step, "stream", False)

logger.debug(
f"Step configuration: use_tools={step.use_tools}, "
f"send_to_llm={step.send_to_llm}, "
Expand Down
120 changes: 102 additions & 18 deletions clientai/agent/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,15 +222,44 @@ def process_step(self, analysis: str) -> str:
f"Failed to register class steps: {str(e)}"
) from e

def _execute_custom_run(self, input_data: Any) -> Any:
"""Execute the custom run method if defined."""
def _execute_custom_run(
self,
agent: Any,
input_data: Any,
engine: StepExecutionProtocol,
stream_override: Optional[bool] = None,
) -> Any:
"""Execute the custom run method with proper stream handling."""
try:
logger.info("Using custom run method")
if self._custom_run is None:
raise WorkflowError("Custom run method is None")
return self._custom_run(input_data)

try:
if stream_override is not None:
original_steps = {}
for step in self._steps.values():
original_steps[step.name] = getattr(
step, "stream", False
)
object.__setattr__(step, "stream", stream_override)

result = self._custom_run(input_data)
return result

finally:
if (
stream_override is not None
and "original_steps" in locals()
):
for step in self._steps.values():
object.__setattr__(
step, "stream", original_steps[step.name]
)

except Exception as e:
raise WorkflowError(f"Custom run method failed: {str(e)}")
logger.error(f"Custom run method failed: {e}")
raise WorkflowError(f"Custom run method failed: {str(e)}") from e

def _initialize_execution(self, agent: Any, input_data: Any) -> None:
"""Process a custom run method if found.
Expand Down Expand Up @@ -401,6 +430,30 @@ def _handle_step_result(
return result
return None

def _get_step_stream_setting(
self,
step: Step,
stream_override: Optional[bool],
is_intermediate_step: bool,
) -> bool:
"""Determine the appropriate streaming setting for a step.
Args:
step: The step being executed
stream_override: Optional streaming override from run()
is_intermediate_step: Whether this is an intermediate step
Returns:
bool: The determined streaming setting
"""
if stream_override is True and is_intermediate_step:
return False

if stream_override is not None:
return stream_override

return getattr(step, "stream", False)

def execute(
self,
agent: Any,
Expand Down Expand Up @@ -439,33 +492,36 @@ def execute(
agent.context.set_input(input_data)
agent.context.increment_iteration()

engine._current_agent = agent # type: ignore

try:
if self._custom_run:
return self._execute_custom_run(input_data)
return self._execute_custom_run(
agent=agent,
input_data=input_data,
engine=engine,
stream_override=stream_override,
)

last_result = input_data
steps = list(self._steps.values())

for step in self._steps.values():
for step in steps[:-1]:
try:
logger.info(
f"Executing step: {step.name} ({step.step_type})"
)

current_stream = (
stream_override
if stream_override is not None
else getattr(step, "stream", False)
)

param_count = self._get_step_parameters(step)
available_results = len(agent.context.last_results) + 1

self._validate_parameter_count(
step, param_count, available_results
)

current_stream = self._get_step_stream_setting(
step=step,
stream_override=stream_override,
is_intermediate_step=True,
)

result = engine.execute_step(
step,
last_result,
Expand All @@ -479,7 +535,6 @@ def execute(
last_result = step_result

logger.debug(f"Step {step.name} completed")

except (StepError, ValueError) as e:
logger.error(f"Error in step '{step.name}': {e}")
if step.config.required:
Expand All @@ -490,11 +545,40 @@ def execute(
)
continue

if steps:
final_step = steps[-1]

param_count = self._get_step_parameters(final_step)
available_results = len(agent.context.last_results) + 1
self._validate_parameter_count(
final_step, param_count, available_results
)

current_stream = self._get_step_stream_setting(
step=final_step,
stream_override=stream_override,
is_intermediate_step=False,
)

result = engine.execute_step(
final_step, last_result, stream=current_stream
)

if not current_stream:
self._handle_step_result(final_step, result, agent)

return result

logger.info("Workflow execution completed")
return last_result

finally:
engine._current_agent = None # type: ignore
except (StepError, WorkflowError, ValueError):
raise
except Exception as e:
logger.error(f"Unexpected workflow execution error: {e}")
raise WorkflowError(
f"Unexpected workflow execution error: {str(e)}"
) from e

except (StepError, WorkflowError, ValueError):
raise
Expand Down

0 comments on commit 6a17ff7

Please sign in to comment.