-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathapp.py
More file actions
174 lines (145 loc) · 5.68 KB
/
app.py
File metadata and controls
174 lines (145 loc) · 5.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
"""MammoChat web application.
This module implements the main web interface for the MammoChat application using Streamlit.
It provides an interactive chat interface that allows users to ask questions about breast cancer
and receive reliable information from trusted medical sources.
The application uses a RAG (Retrieval-Augmented Generation) system to ensure accurate
information delivery from reputable sources like BreastCancer.org and Komen.org.
"""
from __future__ import annotations
from typing import Literal, TypedDict
import asyncio
import streamlit as st
import json
import logfire
from supabase import Client
from openai import AsyncOpenAI
import os
from config import config, TRUSTED_SOURCES
try:
    # Export the key via the environment so AsyncOpenAI() below can pick it
    # up without an explicit api_key argument.
    os.environ["OPENAI_API_KEY"] = config.openai_api_key
    logfire.info("OpenAI API key configured successfully")
except Exception as e:
    # Top-level boundary: surface the failure in the UI and the log rather
    # than crashing the whole Streamlit script at import time.
    st.error(f"Failed to configure OpenAI API key: {str(e)}")
    logfire.error("OpenAI API key configuration failed", error=str(e))
from pydantic_ai.messages import (
ModelMessage,
ModelRequest,
ModelResponse,
SystemPromptPart,
UserPromptPart,
TextPart,
ToolCallPart,
ToolReturnPart,
RetryPromptPart,
ModelMessagesTypeAdapter
)
from agent import chat_agent, SourceDeps
try:
    # Module-level client reused by run_agent_with_streaming via SourceDeps.
    openai_client = AsyncOpenAI()  # Will use OPENAI_API_KEY from environment
    logfire.info("OpenAI client initialized successfully")
except Exception as e:
    # Report and continue; a failure here leaves `openai_client` undefined,
    # so later use will raise NameError — the UI error is the primary signal.
    st.error(f"Failed to initialize OpenAI client: {str(e)}")
    logfire.error("OpenAI client initialization failed", error=str(e))
try:
    # Supabase client used by the agent's retrieval tools (via SourceDeps).
    supabase: Client = Client(
        config.supabase_url,
        config.supabase_service_key
    )
    logfire.info("Supabase client initialized successfully")
except Exception as e:
    # Same boundary pattern as the OpenAI client above: show the error in
    # the UI; `supabase` stays undefined on failure.
    st.error(f"Failed to initialize Supabase client: {str(e)}")
    logfire.error("Supabase client initialization failed", error=str(e))
# Enable logging in cloud environment
# NOTE(review): logfire.configure() runs AFTER the logfire.info/error calls
# in the initialization blocks above, so those earlier calls use logfire's
# default configuration — confirm this ordering is intended.
if 'STREAMLIT_CLOUD' in os.environ:
    logfire.configure(send_to_logfire='always')
else:
    logfire.configure(send_to_logfire='never')
class ChatMessage(TypedDict):
    """Format of messages sent to the browser/API."""
    # Originator of the message: the end user or the model.
    role: Literal['user', 'model']
    # Creation time as a string. NOTE(review): nothing in this module
    # populates this field — confirm the expected format with its producer.
    timestamp: str
    # The message text.
    content: str
def display_message_part(part):
    """Render a single message part in the Streamlit chat UI.

    System prompts, user prompts, and plain text parts are displayed in
    their respective chat-message containers; any other part kind
    (tool calls, tool returns, retries) is skipped.
    """
    kind = part.part_kind
    if kind == 'system-prompt':
        with st.chat_message("system"):
            st.markdown(f"**System**: {part.content}")
    elif kind == 'user-prompt':
        with st.chat_message("user"):
            st.markdown(part.content)
    elif kind == 'text':
        with st.chat_message("assistant"):
            st.markdown(part.content)
async def run_agent_with_streaming(user_input: str) -> None:
    """
    Run the agent with streaming text for the user_input prompt,
    while maintaining the entire conversation in `st.session_state.messages`.
    """
    # Prepare dependencies the agent's tools need (Supabase + OpenAI clients).
    deps = SourceDeps(
        supabase=supabase,
        openai_client=openai_client
    )
    # Run the agent in a stream. History excludes the last entry
    # (messages[:-1]) because the just-appended user turn is passed
    # separately as `user_input` by main().
    async with chat_agent.run_stream(
        user_input,
        deps=deps,
        message_history=st.session_state.messages[:-1],  # pass entire conversation so far
    ) as result:
        # We'll gather partial text to show incrementally
        partial_text = ""
        message_placeholder = st.empty()
        # Render partial text as it arrives
        async for chunk in result.stream_text(delta=True):
            partial_text += chunk
            message_placeholder.markdown(partial_text)
        # Now that the stream is finished, we have a final result.
        # Add new messages from this run, excluding user-prompt messages
        # (the user turn is already stored in session state by main()).
        filtered_messages = [msg for msg in result.new_messages()
                             if not (hasattr(msg, 'parts') and
                                     any(part.part_kind == 'user-prompt' for part in msg.parts))]
        st.session_state.messages.extend(filtered_messages)
        # Add the final response to the messages
        # NOTE(review): result.new_messages() may already include the final
        # text response, in which case this appends a duplicate — verify
        # against pydantic_ai's run_stream semantics.
        st.session_state.messages.append(
            ModelResponse(parts=[TextPart(content=partial_text)])
        )
async def main():
    """Render the chat page: replay history, accept input, stream a reply."""
    st.title("Mammo.Chat™ -- Breast Cancer AI")
    st.write(f"""
    AI-Guided Navigation of Breast Cancer from Trusted Sources:
    {', '.join(TRUSTED_SOURCES)}
    """)

    # Session state survives Streamlit reruns; seed the history once.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay the conversation so far. Each stored entry is a ModelRequest
    # or ModelResponse; render each of its parts individually.
    for msg in st.session_state.messages:
        if isinstance(msg, (ModelRequest, ModelResponse)):
            for part in msg.parts:
                display_message_part(part)

    user_input = st.chat_input("What questions do you have about breast cancer?")
    if not user_input:
        return

    # Record the new user turn in the history before running the agent.
    st.session_state.messages.append(
        ModelRequest(parts=[UserPromptPart(content=user_input)])
    )
    # Echo the user's prompt in the UI.
    with st.chat_message("user"):
        st.markdown(user_input)

    # Stream the assistant's response into its chat container.
    with st.chat_message("assistant"):
        await run_agent_with_streaming(user_input)
if __name__ == "__main__":
    # Streamlit executes this script top-to-bottom on each rerun;
    # asyncio.run drives the async main() to completion.
    asyncio.run(main())