-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
137 lines (111 loc) · 4.06 KB
/
main.py
File metadata and controls
137 lines (111 loc) · 4.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import requests
import databases
import sqlalchemy
from fastapi import FastAPI, Depends, Security, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.security.api_key import APIKey
from fastapi.security.api_key import APIKeyHeader
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from random import random
from starlette.status import HTTP_403_FORBIDDEN
from typing import List
from contextlib import asynccontextmanager
# --- Runtime configuration, read once from the environment at import time ---

# Bearer token sent to the OpenAI API; None when the variable is unset.
GPT_API_TOKEN = os.environ.get("GPT_API_TOKEN")

# Shared secret that clients must present in the "x-api-key" header.
API_KEY = os.environ.get("API_KEY")

# Chat model name and the instruction text prepended to every user query.
GPT_MODEL = os.environ.get("GPT_MODEL", "gpt-3.5-turbo")
GPT_PROMPT = os.environ.get("GPT_PROMPT", "Antworte kurz:")

# A fixed temperature is intentionally disabled; the request code draws a
# random temperature per call instead.
# GPT_TEMPERATURE = float(os.getenv("GPT_TEMPERATURE", default="0.5"))

# Upper bound on completion length that is forwarded as "max_tokens".
GPT_MAX_TOKENS = int(os.environ.get("GPT_MAX_TOKENS", "256"))

# Connection URL for the interaction log; defaults to a local SQLite file.
DATABASE_URL = os.environ.get("DATABASE_URL", "sqlite:///./test.db")
# Async database handle used by the request handlers.
database = databases.Database(DATABASE_URL)

# Schema registry that the table definition below is attached to.
metadata = sqlalchemy.MetaData()

# One row per logged GPT exchange: the query, the temperature it was sent
# with, the API's "created" timestamp, the reply text and the token usage.
interactions_table = sqlalchemy.Table(
    "interactions",
    metadata,
    sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True, autoincrement=True),
    *(
        sqlalchemy.Column(column_name, column_type)
        for column_name, column_type in (
            ("query", sqlalchemy.String),
            ("temperature", sqlalchemy.Float),
            ("timestamp", sqlalchemy.Integer),
            ("response", sqlalchemy.String),
            ("prompt_tokens", sqlalchemy.Integer),
            ("completion_tokens", sqlalchemy.Integer),
            ("total_tokens", sqlalchemy.Integer),
        )
    ),
)

# Synchronous engine, used here only to create the table at import time.
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(bind=engine)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Hold the database connection open for the lifetime of the application.

    Connects the module-level ``database`` before the application starts
    serving and disconnects it again when the application shuts down.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).

    Yields:
        None. Control is handed back to the framework while the app runs.
    """
    await database.connect()
    try:
        yield
    finally:
        # Release the connection even when an exception is thrown into the
        # generator at the yield point; otherwise the disconnect is skipped.
        await database.disconnect()
# Header scheme that reads the client key from "x-api-key". With
# auto_error=False the scheme itself does not reject requests; the decision
# is left to verify_api_key.
api_key_header = APIKeyHeader(name="x-api-key", auto_error=False)

# The application; its lifespan handler manages the database connection.
app = FastAPI(lifespan=lifespan)

# Serve the front-end assets (main.js / main.css) from the local "static" dir.
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")
# Request body accepted by POST /v1/completions.
# (Kept as comments rather than a docstring so the generated schema is
# unchanged.)
class CompletionRequest(BaseModel):
    # Free-text user query that is forwarded to the GPT API.
    query: str
# Response shape for one row of the "interactions" table, as returned by
# GET /interactions. Field names mirror the table's column names.
# (Kept as comments rather than a docstring so the generated schema is
# unchanged.)
class Interaction(BaseModel):
    # Auto-incremented primary key of the row.
    id: int

    # What was asked and with which sampling temperature.
    query: str
    temperature: float

    # Value of the "created" field from the API response.
    timestamp: int

    # Text of the model's reply.
    response: str

    # Token accounting copied from the API response's "usage" object.
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
async def query_gpt_completions_api(gpt_query: str, timeout: float = 60.0) -> str:
    """Send a query to the OpenAI chat completions API and log the exchange.

    The configured prompt is prepended to ``gpt_query`` and the request is
    sent with a temperature drawn uniformly at random from [0, 1). On success
    the query, temperature, reply and token usage are inserted into the
    ``interactions`` table.

    Args:
        gpt_query: The user's query text.
        timeout: Seconds to wait for the OpenAI API before giving up.

    Returns:
        The reply text on success. NOTE: if the response payload does not
        have the expected shape (e.g. an API error object), the decoded
        payload itself is returned instead and nothing is logged.

    Raises:
        requests.exceptions.RequestException: On connection failure or
            timeout while talking to the OpenAI API.
        ValueError: If the response body is not valid JSON.
    """
    # Local imports keep this block self-contained; both are stdlib.
    import asyncio
    from functools import partial

    temperature = random()  # GPT_TEMPERATURE

    send_request = partial(
        requests.post,
        url="https://api.openai.com/v1/chat/completions",
        headers={"Authorization": f"Bearer {GPT_API_TOKEN}"},
        json={
            "model": GPT_MODEL,
            "messages": [{"role": "user", "content": f"{GPT_PROMPT.strip()} {gpt_query}"}],
            "temperature": temperature,
            "max_tokens": GPT_MAX_TOKENS,
        },
        # Without a timeout, requests would wait indefinitely on a stalled
        # connection.
        timeout=timeout,
    )
    # requests is blocking; run it in the default executor so the event loop
    # keeps serving other requests while waiting on the API.
    response = await asyncio.get_running_loop().run_in_executor(None, send_request)
    data = response.json()

    try:
        query_response = data["choices"][0]["message"]["content"]
        sql_query = interactions_table.insert().values(
            query=gpt_query,
            temperature=temperature,
            timestamp=data["created"],
            response=query_response,
            prompt_tokens=data["usage"]["prompt_tokens"],
            completion_tokens=data["usage"]["completion_tokens"],
            total_tokens=data["usage"]["total_tokens"],
        )
    except (KeyError, IndexError, TypeError):
        # Unexpected payload shape (missing key, empty "choices", null
        # field): hand the raw payload back to the caller, as before.
        return data

    await database.execute(sql_query)
    return query_response
async def verify_api_key(api_key_header: str = Security(api_key_header)):
    """Check the "x-api-key" header against the configured API key.

    Args:
        api_key_header: Value of the "x-api-key" request header, or None when
            the header is absent (the header scheme is created with
            ``auto_error=False``).

    Returns:
        The presented key when it matches the configured ``API_KEY``.

    Raises:
        HTTPException: 403 when the header is missing, when no ``API_KEY`` is
            configured, or when the two values differ.
    """
    # Local import keeps this block self-contained; secrets is stdlib.
    import secrets

    # Fail closed: a plain equality check would accept a request with no
    # header whenever API_KEY is unset, because None == None.
    key_is_valid = (
        bool(API_KEY)
        and bool(api_key_header)
        # Constant-time comparison; bytes are used because compare_digest
        # rejects non-ASCII str arguments.
        and secrets.compare_digest(
            api_key_header.encode("utf-8"), API_KEY.encode("utf-8")
        )
    )
    if not key_is_valid:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN, detail="Could not validate API KEY"
        )
    return api_key_header
@app.post("/v1/completions")
async def completions(
    request_body: CompletionRequest,
    api_key: APIKey = Depends(verify_api_key),
):
    # Authenticated endpoint: the verify_api_key dependency must succeed
    # before the body runs. The query text is forwarded to the GPT helper and
    # whatever it returns is sent back to the client.
    answer = await query_gpt_completions_api(request_body.query)
    return answer
@app.get("/interactions", response_model=List[Interaction])
async def interactions():
    # Unfiltered SELECT over the interactions table; the rows are serialized
    # through the Interaction response model.
    return await database.fetch_all(interactions_table.select())
@app.get("/", response_class=HTMLResponse)
async def index():
    # Landing page: an empty HTML shell that loads the static script and
    # stylesheet and calls update_data() once the page has loaded.
    page = """
<html>
<head>
<title>Schreibmaschine</title>
<script src='static/main.js'></script>
<link rel='stylesheet' href='static/main.css'>
</head>
<body onload='update_data()'>
</body>
</html>
"""
    return page