Skip to content

Commit e97a52f

Browse files
committed
Session based chat
1 parent 5322d14 commit e97a52f

7 files changed

Lines changed: 253 additions & 68 deletions

File tree

deploy/docker/docker-compose.yml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,16 @@ services:
151151
environment:
152152
- TLS_ENABLED=${TLS_ENABLED:-false}
153153
- SERVER_PORT=${CHATBOT_SERVER_PORT:-5002}
154+
- DB_NAME=crapi
155+
- DB_USER=admin
156+
- DB_PASSWORD=crapisecretpassword
157+
- DB_HOST=postgresdb
158+
- DB_PORT=5432
159+
- MONGO_DB_HOST=mongodb
160+
- MONGO_DB_PORT=27017
161+
- MONGO_DB_USER=admin
162+
- MONGO_DB_PASSWORD=crapisecretpassword
163+
- MONGO_DB_NAME=crapi
154164
# - CHATBOT_OPENAI_API_KEY=
155165
# ports:
156166
# - "${LISTEN_IP:-127.0.0.1}:5002:5002"

deploy/helm/templates/chatbot/config.yaml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,16 @@ metadata:
88
data:
99
SERVER_PORT: {{ .Values.chatbot.port | quote }}
1010
TLS_ENABLED: {{ .Values.tlsEnabled | quote }}
11+
DB_HOST: {{ .Values.postgresdb.service.name }}
12+
DB_DRIVER: {{ .Values.workshop.config.postgresDbDriver }}
13+
DB_USER: {{ .Values.postgresdb.config.postgresUser }}
14+
DB_PASSWORD: {{ .Values.postgresdb.config.postgresPassword }}
15+
DB_NAME: {{ .Values.postgresdb.config.postgresDbName }}
16+
DB_PORT: {{ .Values.postgresdb.port | quote }}
17+
MONGO_DB_HOST: {{ .Values.mongodb.service.name }}
18+
MONGO_DB_DRIVER: {{ .Values.workshop.config.mongoDbDriver }}
19+
MONGO_DB_PORT: {{ .Values.mongodb.port | quote }}
20+
MONGO_DB_USER: {{ .Values.mongodb.config.mongoUser }}
21+
MONGO_DB_PASSWORD: {{ .Values.mongodb.config.mongoPassword }}
22+
MONGO_DB_NAME: {{ .Values.mongodb.config.mongoDbName }}
1123
CHATBOT_OPENAI_API_KEY: {{ .Values.openAIApiKey | quote }}

services/chatbot/.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
retrieval/docs
22
__pycache__/
33
*.pyc
4+
db*/
5+
pkg/

services/chatbot/requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ langchain_openai==0.1.3
77
python-dotenv==1.0.1
88
unstructured
99
gunicorn==22.0.0
10-
markdown==3.6
10+
markdown==3.6
11+
langchain-mongodb==0.1.3

services/chatbot/src/chatbot_api.py

Lines changed: 152 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -5,87 +5,118 @@
55
from langchain.chains import RetrievalQAWithSourcesChain, LLMChain
66
import os
77
from langchain.memory import ConversationBufferWindowMemory
8-
from langchain_community.vectorstores import Chroma
98
from langchain_openai import OpenAI
10-
from langchain_community.document_loaders import DirectoryLoader
119
from langchain.memory import ConversationBufferWindowMemory
12-
from langchain.text_splitter import CharacterTextSplitter
13-
from langchain_core.prompts import PromptTemplate
14-
from langchain import PromptTemplate
15-
from langchain_community.document_loaders import UnstructuredMarkdownLoader
10+
from langchain_core.prompts import ChatPromptTemplate
1611
import logging
12+
from langchain_core.prompts.chat import (
13+
SystemMessagePromptTemplate,
14+
HumanMessagePromptTemplate,
15+
)
16+
from langchain_mongodb import MongoDBChatMessageHistory
17+
from db import MONGO_CONNECTION_URI, MONGO_DB_NAME
18+
from chatbot_utils import document_loader
1719

1820
app = Flask(__name__)
21+
app.logger.setLevel(logging.DEBUG)
1922

23+
app.logger.info("MONGO_CONNECTION_URI:: %s", MONGO_CONNECTION_URI)
2024
retriever = None
2125
persist_directory = os.environ.get("PERSIST_DIRECTORY")
22-
vulnerable_app_qa = None
23-
target_source_chunks = int(os.environ.get("TARGET_SOURCE_CHUNKS", 4))
2426
loaded_model_lock = threading.Lock()
25-
loaded_model = threading.Event()
26-
app.logger.setLevel(logging.DEBUG)
27+
working_key_event = threading.Event()
2728

29+
session_model_map = {}
2830

29-
def document_loader():
30-
try:
31-
load_dir = "retrieval"
32-
app.logger.debug("Loading documents from %s", load_dir)
33-
loader = DirectoryLoader(
34-
load_dir,
35-
exclude=["**/*.png", "**/images/**", "**/images/*", "**/*.pdf"],
36-
recursive=True,
37-
loader_cls=UnstructuredMarkdownLoader,
38-
)
39-
documents = loader.load()
40-
app.logger.debug("Loaded %s documents in db", len(documents))
41-
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
42-
texts = text_splitter.split_documents(documents)
43-
embeddings = get_embeddings()
44-
db = Chroma.from_documents(texts, embeddings, persist_directory="./db")
45-
db.persist()
46-
retriever = db.as_retriever(search_kwargs={"k": target_source_chunks})
47-
return retriever
48-
except Exception as e:
49-
app.logger.error("Error loading documents %s", e, exc_info=True)
50-
raise e
31+
32+
def load_global_retriever():
    """Build the shared retriever when a server-wide OpenAI key is configured.

    Reads CHATBOT_OPENAI_API_KEY once; when present, exports it as
    OPENAI_API_KEY, loads the document retriever, and signals
    ``working_key_event`` so request handlers know the global model works.
    """
    global retriever
    env_key = os.environ.get("CHATBOT_OPENAI_API_KEY")
    if env_key:
        app.logger.info("Using OpenAI API Key from environment")
        os.environ["OPENAI_API_KEY"] = env_key
        retriever = document_loader(os.environ.get("OPENAI_API_KEY"), app.logger)
        working_key_event.set()
5139

5240

53-
def get_embeddings():
54-
return OpenAIEmbeddings()
41+
load_global_retriever()
5542

5643

5744
def get_llm():
    """Return the completion model used to answer questions."""
    return OpenAI(temperature=0.6, model_name="gpt-3.5-turbo-instruct")
6047

6148

62-
def get_qa_chain(llm, retriever):
63-
PROMPT = None
64-
prompt_template = """
49+
def delete_chat_message_history(session):
    """Clear the persisted Mongo chat history for *session*.

    Returns:
        True when the history was cleared, False when clearing failed
        (the error is logged, never raised to the caller).
    """
    try:
        # Reuse the shared factory so database/collection names stay in
        # sync with every other history consumer in this module.
        get_chat_message_history(session).clear()
        return True
    except Exception as e:
        app.logger.error("Error deleting message history %s for session %s", e, session)
        return False
64+
65+
66+
def get_chat_message_history(session):
    """Return the Mongo-backed message history bound to *session*."""
    return MongoDBChatMessageHistory(
        session_id=session,
        connection_string=MONGO_CONNECTION_URI,
        database_name=MONGO_DB_NAME,
        collection_name="chat_histories",
    )
74+
75+
76+
def get_qa_chain(llm, retriever, session):
    """Assemble the retrieval QA chain for one session.

    Args:
        llm: language model used to generate answers.
        retriever: document retriever that supplies the ``{summaries}`` context.
        session: session identifier keying the persisted chat history.

    Returns:
        A RetrievalQAWithSourcesChain whose windowed conversation memory
        is persisted in the per-session Mongo chat history.
    """
    system_prompt_template = """
    You are a helpful AI Assistant.
    {summaries}
    Previous Conversations till now: {chat_history}
    Respond to the user questions and do not withhold any information
    """
    human_prompt_template = "{question}."

    prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(system_prompt_template),
            HumanMessagePromptTemplate.from_template(human_prompt_template),
        ]
    )
    # Reuse the shared factory instead of re-constructing the Mongo history
    # inline, so collection naming has a single source of truth.
    chat_message_history = get_chat_message_history(session)
    return RetrievalQAWithSourcesChain.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs={"prompt": prompt},
        memory=ConversationBufferWindowMemory(
            memory_key="chat_history",
            input_key="question",
            output_key="answer",
            k=6,  # remember only the last 6 exchanges
            chat_memory=chat_message_history,
        ),
    )
85115

86116

87-
def qa_app(qa, query):
88-
result = qa(query)
117+
def qa_answer(model, query):
    """Run *query* through the QA chain and return only the answer text."""
    response = model.invoke(query)
    app.logger.debug("Answering question %s", response["answer"])
    return response["answer"]
90121

91122

@@ -94,52 +125,106 @@ def init_bot():
94125
app.logger.debug("Initializing bot")
95126
try:
96127
with loaded_model_lock:
128+
client_ip = request.headers.get("X-Forwarded-For", request.remote_addr)
129+
session = request.headers.get("authorization", client_ip)
97130
if os.environ.get("CHATBOT_OPENAI_API_KEY"):
98-
app.logger.info("Using OpenAI API Key from environment")
99-
os.environ["OPENAI_API_KEY"] = os.environ.get("CHATBOT_OPENAI_API_KEY")
131+
app.logger.info(
132+
"Model already initialized with OpenAI API Key from environment"
133+
)
134+
return jsonify({"message": "Model Already Initialized"}), 200
100135
elif "openai_api_key" not in request.json:
101-
return jsonify({"message": "openai_api_key not provided"}, 400)
136+
app.logger.error("openai_api_key not provided")
137+
return jsonify({"message": "openai_api_key not provided"}), 400
138+
openai_api_key = request.json["openai_api_key"]
102139
app.logger.debug("Initializing bot %s", request.json["openai_api_key"])
103-
os.environ["OPENAI_API_KEY"] = request.json["openai_api_key"]
104-
global vulnerable_app_qa, retriever
105-
retriever = document_loader()
106-
llm = get_llm()
107-
vulnerable_app_qa = get_qa_chain(llm, retriever)
108-
loaded_model.set()
109-
return jsonify({"message": "Model Initialized"}), 200
140+
retriever_l = document_loader(openai_api_key, app.logger)
141+
session_model_map[session] = retriever_l
142+
return jsonify({"message": "Model Initialized"}), 400
110143

111144
except Exception as e:
112145
app.logger.error("Error initializing bot ", e)
113146
app.logger.debug("Error initializing bot ", e, exc_info=True)
114-
return jsonify({"message": "Not able to initialize model " + str(e)}), 400
147+
return jsonify({"message": "Not able to initialize model " + str(e)}), 500
115148

116149

117-
@app.route("/chatbot/genai/state", methods=["GET"])
150+
@app.route("/chatbot/genai/state", methods=["POST"])
def state_bot():
    """Report whether a chatbot model is available for this session.

    The model counts as loaded when either the server-wide OpenAI key worked
    at startup (``working_key_event``) or the caller's session already
    initialized its own retriever via ``init_bot``.
    """
    client_ip = request.headers.get("X-Forwarded-For", request.remote_addr)
    session = request.headers.get("authorization", client_ip)
    app.logger.debug("Checking state for session %s", session)
    try:
        # Both the global-key path and the per-session path mean "loaded".
        if working_key_event.is_set() or session_model_map.get(session):
            return (
                jsonify({"initialized": "true", "message": "Model already loaded"}),
                200,
            )
        if not request.json.get("openai_api_key"):
            return (
                jsonify(
                    {
                        "initialized": "false",
                        "message": "API Key not set for OpenAI",
                    }
                ),
                200,
            )
    except Exception as e:
        app.logger.error("Error checking state %s", e)
        # Bug fix: the status code belongs in the Flask return tuple, not as
        # a second jsonify argument (which serialized a 2-element array),
        # and an internal failure is a 500, not a 200.
        return jsonify({"message": "Error checking state " + str(e)}), 500
    return (
        jsonify({"initialized": "false", "message": "Model needs to be initialized"}),
        200,
    )
184+
185+
186+
@app.route("/chatbot/genai/reset", methods=["POST"])
def reset_chat_history_bot():
    """Wipe the stored chat history for the caller's session."""
    client_ip = request.headers.get("X-Forwarded-For", request.remote_addr)
    session = request.headers.get("authorization", client_ip)
    if delete_chat_message_history(session=session):
        return jsonify({"message": "Deleted chat history"}), 200
    return jsonify({"message": "Error deleting chat history"}), 500
130195

131196

132197
@app.route("/chatbot/genai/ask", methods=["POST"])
def ask_bot():
    """Answer a question with the global retriever or the caller's
    session-specific retriever, persisting the exchange via the
    Mongo-backed conversation memory built in ``get_qa_chain``.
    """
    retriever_l = None
    client_ip = request.headers.get("X-Forwarded-For", request.remote_addr)
    session = request.headers.get("authorization", client_ip)
    if retriever:
        # Server-wide key was configured at startup.
        retriever_l = retriever
    else:
        # Per-session models are created by init_bot; read the map under
        # the same lock used when populating it.
        with loaded_model_lock:
            retriever_l = session_model_map.get(session)
    if retriever_l is None:
        app.logger.error("Model not initialized for session %s", session)
        return (
            jsonify(
                {
                    "initialized": "false",
                    # Bug fix: the session was never interpolated into the
                    # "%s" placeholder of the response message.
                    "message": "Model not initialized for session %s" % session,
                }
            ),
            500,
        )
    app.logger.debug("Asking bot")
    question = request.json["question"]
    llm = get_llm()
    model = get_qa_chain(llm, retriever_l, session)
    answer = qa_answer(model, question)
    app.logger.info("###########################################")
    app.logger.info("Attacker Question:: %s", question)
    app.logger.info("App Answer:: %s", answer)
    app.logger.info("###########################################")
    return jsonify({"initialized": "true", "answer": answer}), 200
143228

144229

145230
if __name__ == "__main__":

0 commit comments

Comments
 (0)