Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions sdks/python/apache_beam/ml/inference/agent_development_kit.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,17 +229,6 @@ def run_inference(
for element in batch:
session_id: str = inference_args.get("session_id", str(uuid.uuid4()))

# Ensure a session exists for this invocation
try:
model.session_service.create_session(
app_name=self._app_name,
user_id=user_id,
session_id=session_id,
)
except sessions.SessionExistsError:
# It's okay if the session already exists for shared session IDs.
pass

# Wrap plain strings in a Content object
if isinstance(element, str):
# pyrefly: ignore[bad-instantiation]
Expand All @@ -249,7 +238,8 @@ def run_inference(
message = element

agent_invocations.append(
self._invoke_agent(model, user_id, session_id, message))
self._invoke_agent(
model, user_id, session_id, self._app_name, message))
elements_with_sessions.append(element)

# Run all agent invocations concurrently
Expand All @@ -274,6 +264,7 @@ async def _invoke_agent(
runner: "Runner",
user_id: str,
session_id: str,
app_name: str,
message: genai_Content,
) -> Optional[str]:
"""Drives the ADK event loop and returns the final response text.
Expand All @@ -288,6 +279,17 @@ async def _invoke_agent(
The text of the agent's final response, or ``None`` if the agent
produced no final text response.
"""
# Check for your specific session ID
try:
# Attempt to get the specific session
await runner.session_service.get_session(session_id)
except Exception as e:
await runner.session_service.create_session(
app_name=app_name,
user_id=user_id,
session_id=session_id,
)

async for event in runner.run_async(
user_id=user_id,
session_id=session_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ async def _async_gen(*args, **kwargs):
runner.agent = agent
runner.run_async = mock.MagicMock(side_effect=_async_gen)
runner.session_service = mock.MagicMock()
runner.session_service.get_session = mock.AsyncMock(
side_effect=Exception("Session not found"))
runner.session_service.create_session = mock.AsyncMock()
return runner


Expand Down Expand Up @@ -266,6 +269,8 @@ async def _async_gen(*args, **kwargs):
runner.agent = agent
runner.run_async = mock.MagicMock(side_effect=_async_gen)
runner.session_service = mock.MagicMock()
runner.session_service.get_session = mock.AsyncMock(side_effect=Exception())
runner.session_service.create_session = mock.AsyncMock()

handler = ADKAgentModelHandler(agent=agent)
results = list(handler.run_inference(batch=["hello"], model=runner))
Expand All @@ -287,6 +292,8 @@ async def _async_gen(*args, **kwargs):
runner.agent = agent
runner.run_async = mock.MagicMock(side_effect=_async_gen)
runner.session_service = mock.MagicMock()
runner.session_service.get_session = mock.AsyncMock(side_effect=Exception())
runner.session_service.create_session = mock.AsyncMock()

handler = ADKAgentModelHandler(agent=agent)
results = list(handler.run_inference(batch=["hello"], model=runner))
Expand All @@ -313,6 +320,8 @@ async def _async_gen(*args, **kwargs):
runner.agent = agent
runner.run_async = mock.MagicMock(side_effect=_async_gen)
runner.session_service = mock.MagicMock()
runner.session_service.get_session = mock.AsyncMock(side_effect=Exception())
runner.session_service.create_session = mock.AsyncMock()

handler = ADKAgentModelHandler(agent=agent)
results = list(handler.run_inference(batch=["hi"], model=runner))
Expand All @@ -326,7 +335,7 @@ def test_invoke_agent_static_method_directly(self):

result = asyncio.run(
ADKAgentModelHandler._invoke_agent(
runner, "user", "session-1", mock.MagicMock()))
runner, "user", "session-1", "test_app", mock.MagicMock()))
self.assertEqual(result, "direct result")


Expand Down
Loading