diff --git a/helpers/fasta2a_server.py b/helpers/fasta2a_server.py index 89afade1ca..5b61045ddb 100644 --- a/helpers/fasta2a_server.py +++ b/helpers/fasta2a_server.py @@ -1,5 +1,6 @@ # noqa: D401 (docstrings) – internal helper import asyncio +import os import uuid import atexit from typing import Any, List @@ -59,6 +60,92 @@ async def update_task(self, **kwargs): Message = Artifact = AgentProvider = Skill = Any # type: ignore _PRINTER = PrintStyle(italic=True, font_color="purple", padding=False) +FAILURE_REASON_MAX_CHARS = 300 +A2A_TASK_RESULT_TIMEOUT_ENV = "A2A_TASK_RESULT_TIMEOUT_SECONDS" +DEFAULT_TASK_RESULT_TIMEOUT_SECONDS = 30.0 +MAX_TASK_RESULT_TIMEOUT_SECONDS = 120.0 + + +def _task_result_timeout_seconds() -> float: + try: + configured = float( + os.getenv( + A2A_TASK_RESULT_TIMEOUT_ENV, + str(DEFAULT_TASK_RESULT_TIMEOUT_SECONDS), + ) + or str(DEFAULT_TASK_RESULT_TIMEOUT_SECONDS) + ) + except ValueError: + configured = DEFAULT_TASK_RESULT_TIMEOUT_SECONDS + return max(1.0, min(MAX_TASK_RESULT_TIMEOUT_SECONDS, configured)) + + +def _sanitize_failure_reason(reason: object) -> str: + text = " ".join(str(reason).split()) + if len(text) > FAILURE_REASON_MAX_CHARS: + text = text[:FAILURE_REASON_MAX_CHARS].rstrip() + "..." + return text or "unknown error" + + +def _build_failure_message(reason: str) -> Message: # type: ignore + return { + 'role': 'agent', + 'parts': [{'kind': 'text', 'text': f"Agent Zero A2A task failed: {reason}"}], + 'kind': 'message', + 'message_id': str(uuid.uuid4()) + } + + +def _build_tool_output_artifact(text: str, tool_name: str) -> Artifact: # type: ignore + return { + 'artifact_id': str(uuid.uuid4()), + 'name': 'captured_tool_output', + 'description': f"Captured output from {tool_name}", + 'parts': [{'kind': 'text', 'text': text}], + 'metadata': { + 'source': 'a2a_finalization_timeout_fallback', + 'tool_name': tool_name, + }, + } + + +def _cleanup_context(context: AgentContext | None, task_id: str, outcome: str) -> None: + if not context: + return + context.reset() + AgentContext.remove(context.id) + remove_chat(context.id) + _PRINTER.print(f"[A2A] Cleaned up {outcome} context {context.id} for task {task_id}") + + +def _get_history_output(context: AgentContext | None) -> list[Any]: + if not context: + return [] + try: + history = context.agent0.history + if hasattr(history, "output"): + output = history.output() + return output if isinstance(output, list) else [] + except Exception as e: + _PRINTER.print(f"[A2A] Unable to inspect tool output history: {e}") + return [] + + +def _find_latest_tool_output(context: AgentContext | None) -> tuple[str, str] | None: + for message in reversed(_get_history_output(context)): + if not isinstance(message, dict): + continue + content = message.get("content") + if not isinstance(content, dict): + continue + tool_result = content.get("tool_result") + if not isinstance(tool_result, str) or not tool_result.strip(): + continue + tool_name = content.get("tool_name") + if not isinstance(tool_name, str) or not tool_name.strip(): + tool_name = "tool" + return tool_name, tool_result.strip() + return None class AgentZeroWorker(Worker): # type: ignore[misc] @@ -71,7 +158,9 @@ def __init__(self, broker, storage): async def run_task(self, params: Any) -> None: # params: TaskSendParams """Execute a task by processing the message through Agent Zero.""" context = None + task_id = params.get('id', 'unknown') if isinstance(params, dict) else 'unknown' try: + _PRINTER.print(f"[A2A] Task received: {task_id}") task_id = params['id'] message = params['message'] @@ -101,8 +190,55 @@ async def run_task(self, params: Any) -> None: # params: TaskSendParams ) # Process message through Agent Zero (includes response) + _PRINTER.print(f"[A2A] Task {task_id}: entering context.communicate") task = context.communicate(agent_message) - result_text = await task.result() + _PRINTER.print(f"[A2A] Task {task_id}: context.communicate returned") + timeout_seconds = _task_result_timeout_seconds() + _PRINTER.print( + f"[A2A] Task {task_id}: awaiting task.result() " + f"with timeout {timeout_seconds:g}s" + ) + try: + result_text = await asyncio.wait_for( + task.result(), + timeout=timeout_seconds, + ) + _PRINTER.print(f"[A2A] Task {task_id}: task.result() completed") + except asyncio.TimeoutError: + reason = f"task.result() timed out after {timeout_seconds:g}s" + _PRINTER.print(f"[A2A] Task {task_id}: task.result() exception: {reason}") + tool_output = _find_latest_tool_output(context) + if tool_output: + tool_name, output_text = tool_output + _PRINTER.print(f"[A2A] Task {task_id}: tool output captured from {tool_name}") + _PRINTER.print(f"[A2A] Task {task_id}: final response started but timed out") + _PRINTER.print(f"[A2A] Task {task_id}: artifact fallback used") + await self.storage.update_task( # type: ignore[attr-defined] + task_id=task_id, + state='completed', + new_artifacts=[_build_tool_output_artifact(output_text, tool_name)] + ) + _cleanup_context(context, task_id, "completed with artifact fallback") + return + _PRINTER.print(f"[A2A] Task {task_id}: updating task failed") + await self.storage.update_task( # type: ignore[attr-defined] + task_id=task_id, + state='failed', + new_messages=[_build_failure_message(reason)] + ) + _cleanup_context(context, task_id, "timed out") + return + except Exception as e: + reason = f"{type(e).__name__}: {_sanitize_failure_reason(e)}" + _PRINTER.print(f"[A2A] Task {task_id}: task.result() exception: {reason}") + _PRINTER.print(f"[A2A] Task {task_id}: updating task failed") + await self.storage.update_task( # type: ignore[attr-defined] + task_id=task_id, + state='failed', + new_messages=[_build_failure_message(reason)] + ) + _cleanup_context(context, task_id, "failed") + return # Build A2A message from result response_message: Message = { # type: ignore @@ -112,6 +248,7 @@ async def run_task(self, params: Any) -> None: # params: TaskSendParams 'message_id': str(uuid.uuid4()) } + _PRINTER.print(f"[A2A] Task {task_id}: updating task completed") await self.storage.update_task( # type: ignore[attr-defined] task_id=task_id, state='completed', @@ -119,25 +256,22 @@ async def run_task(self, params: Any) -> None: # params: TaskSendParams ) # Clean up context like non-persistent MCP chats - context.reset() - AgentContext.remove(context.id) - remove_chat(context.id) + _cleanup_context(context, task_id, "completed") _PRINTER.print(f"[A2A] Completed task {task_id} and cleaned up context") except Exception as e: - _PRINTER.print(f"[A2A] Error processing task {params.get('id', 'unknown')}: {e}") + reason = f"{type(e).__name__}: {_sanitize_failure_reason(e)}" + _PRINTER.print(f"[A2A] Error processing task {task_id}: {reason}") + _PRINTER.print(f"[A2A] Task {task_id}: updating task failed") await self.storage.update_task( - task_id=params.get('id', 'unknown'), - state='failed' + task_id=task_id, + state='failed', + new_messages=[_build_failure_message(reason)] ) # Clean up context even on failure to prevent resource leaks - if context: - context.reset() - AgentContext.remove(context.id) - remove_chat(context.id) - _PRINTER.print(f"[A2A] Cleaned up failed context {context.id}") + _cleanup_context(context, task_id, "failed") async def cancel_task(self, params: Any) -> None: # params: TaskIdParams """Cancel a running task.""" diff --git a/tests/test_fasta2a_server_worker.py b/tests/test_fasta2a_server_worker.py new file mode 100644 index 0000000000..20a28fbda0 --- /dev/null +++ b/tests/test_fasta2a_server_worker.py @@ -0,0 +1,198 @@ +import asyncio +import sys +from pathlib import Path + +import pytest + +PROJECT_ROOT = Path(__file__).resolve().parents[1] +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + +from helpers import fasta2a_server + + +class FakeStorage: + def __init__(self): + self.updates = [] + + async def update_task(self, **kwargs): + self.updates.append(kwargs) + return kwargs + + +class FakeLog: + def log(self, **kwargs): + pass + + +class FakeHistory: + def __init__(self): + self.messages = [] + + def output(self): + return list(self.messages) + + +class FakeAgent: + def __init__(self): + self.history = FakeHistory() + + +class FakeContext: + removed = [] + reset_count = 0 + latest = None + + def __init__(self, cfg, type): + self.id = "ctx-test" + self.log = FakeLog() + self.agent0 = FakeAgent() + FakeContext.latest = self + + def reset(self): + FakeContext.reset_count += 1 + + def communicate(self, message): + raise NotImplementedError + + @staticmethod + def remove(context_id): + FakeContext.removed.append(context_id) + + +class HangingTask: + async def result(self): + await asyncio.Event().wait() + + +class FailingTask: + async def result(self): + raise RuntimeError("boom\nwith details") + + +class CompletedTask: + async def result(self): + return "final response" + + +def _params(): + return { + "id": "task-123", + "message": { + "role": "user", + "parts": [{"kind": "text", "text": "hello"}], + }, + } + + +@pytest.fixture(autouse=True) +def patch_runtime(monkeypatch): + FakeContext.removed = [] + FakeContext.reset_count = 0 + FakeContext.latest = None + monkeypatch.delenv(fasta2a_server.A2A_TASK_RESULT_TIMEOUT_ENV, raising=False) + monkeypatch.setattr(fasta2a_server, "AgentContext", FakeContext) + monkeypatch.setattr(fasta2a_server, "initialize_agent", lambda: object()) + monkeypatch.setattr(fasta2a_server, "remove_chat", lambda context_id: None) + + +@pytest.mark.asyncio +async def test_run_task_timeout_marks_task_failed(monkeypatch): + monkeypatch.setenv(fasta2a_server.A2A_TASK_RESULT_TIMEOUT_ENV, "0.01") + storage = FakeStorage() + worker = fasta2a_server.AgentZeroWorker(broker=None, storage=storage) + monkeypatch.setattr(FakeContext, "communicate", lambda self, message: HangingTask()) + + await worker.run_task(_params()) + + assert storage.updates[-1]["task_id"] == "task-123" + assert storage.updates[-1]["state"] == "failed" + text = storage.updates[-1]["new_messages"][0]["parts"][0]["text"] + assert "timed out" in text + assert FakeContext.reset_count == 1 + assert FakeContext.removed == ["ctx-test"] + + +@pytest.mark.asyncio +async def test_run_task_timeout_after_tool_output_completes_with_artifact(monkeypatch): + monkeypatch.setenv(fasta2a_server.A2A_TASK_RESULT_TIMEOUT_ENV, "0.01") + storage = FakeStorage() + worker = fasta2a_server.AgentZeroWorker(broker=None, storage=storage) + + def communicate(self, message): + self.agent0.history.messages.append({ + "ai": False, + "content": { + "tool_name": "code_execution_tool", + "tool_result": "connected\n", + }, + }) + return HangingTask() + + monkeypatch.setattr(FakeContext, "communicate", communicate) + + await worker.run_task(_params()) + + assert storage.updates[-1]["task_id"] == "task-123" + assert storage.updates[-1]["state"] == "completed" + artifact = storage.updates[-1]["new_artifacts"][0] + assert artifact["name"] == "captured_tool_output" + assert artifact["metadata"]["tool_name"] == "code_execution_tool" + assert artifact["parts"] == [{"kind": "text", "text": "connected"}] + assert "new_messages" not in storage.updates[-1] + assert FakeContext.reset_count == 1 + assert FakeContext.removed == ["ctx-test"] + + +@pytest.mark.asyncio +async def test_run_task_result_exception_marks_task_failed(monkeypatch): + monkeypatch.setenv(fasta2a_server.A2A_TASK_RESULT_TIMEOUT_ENV, "0.01") + storage = FakeStorage() + worker = fasta2a_server.AgentZeroWorker(broker=None, storage=storage) + monkeypatch.setattr(FakeContext, "communicate", lambda self, message: FailingTask()) + + await worker.run_task(_params()) + + assert storage.updates[-1]["task_id"] == "task-123" + assert storage.updates[-1]["state"] == "failed" + text = storage.updates[-1]["new_messages"][0]["parts"][0]["text"] + assert "RuntimeError" in text + assert "\n" not in text + assert FakeContext.reset_count == 1 + assert FakeContext.removed == ["ctx-test"] + + +@pytest.mark.asyncio +async def test_run_task_final_response_completes_normally(monkeypatch): + monkeypatch.setenv(fasta2a_server.A2A_TASK_RESULT_TIMEOUT_ENV, "0.01") + storage = FakeStorage() + worker = fasta2a_server.AgentZeroWorker(broker=None, storage=storage) + monkeypatch.setattr(FakeContext, "communicate", lambda self, message: CompletedTask()) + + await worker.run_task(_params()) + + assert storage.updates[-1]["task_id"] == "task-123" + assert storage.updates[-1]["state"] == "completed" + message = storage.updates[-1]["new_messages"][0] + assert message["role"] == "agent" + assert message["parts"] == [{"kind": "text", "text": "final response"}] + assert "new_artifacts" not in storage.updates[-1] + assert FakeContext.reset_count == 1 + assert FakeContext.removed == ["ctx-test"] + + +def test_task_result_timeout_uses_env_and_clamps(monkeypatch): + monkeypatch.delenv(fasta2a_server.A2A_TASK_RESULT_TIMEOUT_ENV, raising=False) + assert fasta2a_server._task_result_timeout_seconds() == 30.0 + + monkeypatch.setenv(fasta2a_server.A2A_TASK_RESULT_TIMEOUT_ENV, "120") + assert fasta2a_server._task_result_timeout_seconds() == 120.0 + + monkeypatch.setenv(fasta2a_server.A2A_TASK_RESULT_TIMEOUT_ENV, "999") + assert fasta2a_server._task_result_timeout_seconds() == 120.0 + + monkeypatch.setenv(fasta2a_server.A2A_TASK_RESULT_TIMEOUT_ENV, "0") + assert fasta2a_server._task_result_timeout_seconds() == 1.0 + + monkeypatch.setenv(fasta2a_server.A2A_TASK_RESULT_TIMEOUT_ENV, "not-a-number") + assert fasta2a_server._task_result_timeout_seconds() == 30.0