55from langchain .chains import RetrievalQAWithSourcesChain , LLMChain
66import os
77from langchain .memory import ConversationBufferWindowMemory
8- from langchain_community .vectorstores import Chroma
98from langchain_openai import OpenAI
10- from langchain_community .document_loaders import DirectoryLoader
119from langchain .memory import ConversationBufferWindowMemory
12- from langchain .text_splitter import CharacterTextSplitter
13- from langchain_core .prompts import PromptTemplate
14- from langchain import PromptTemplate
15- from langchain_community .document_loaders import UnstructuredMarkdownLoader
10+ from langchain_core .prompts import ChatPromptTemplate
1611import logging
12+ from langchain_core .prompts .chat import (
13+ SystemMessagePromptTemplate ,
14+ HumanMessagePromptTemplate ,
15+ )
16+ from langchain_mongodb import MongoDBChatMessageHistory
17+ from db import MONGO_CONNECTION_URI , MONGO_DB_NAME
18+ from chatbot_utils import document_loader
1719
# --- Flask app and module-level state -------------------------------------
app = Flask(__name__)
app.logger.setLevel(logging.DEBUG)

# NOTE(review): this logs the full Mongo connection URI, which may embed
# credentials — confirm this is intentional for this (training) app.
app.logger.info("MONGO_CONNECTION_URI:: %s", MONGO_CONNECTION_URI)

retriever = None  # global retriever, set when an env-provided key is available
persist_directory = os.environ.get("PERSIST_DIRECTORY")

loaded_model_lock = threading.Lock()  # guards session_model_map access
working_key_event = threading.Event()  # set once the env-key retriever loads

# per-session retrievers, keyed by the caller's auth header (or client IP)
session_model_map = {}
2830
def load_global_retriever():
    """Build the shared document retriever from an env-provided OpenAI key.

    If CHATBOT_OPENAI_API_KEY is set, export it as OPENAI_API_KEY, load the
    retriever via chatbot_utils.document_loader, and signal readiness through
    working_key_event. Does nothing when the env var is absent.
    """
    global retriever
    # Read the env var once instead of three separate lookups.
    api_key = os.environ.get("CHATBOT_OPENAI_API_KEY")
    if api_key:
        app.logger.info("Using OpenAI API Key from environment")
        os.environ["OPENAI_API_KEY"] = api_key
        retriever = document_loader(api_key, app.logger)
        working_key_event.set()


# Eagerly attempt the env-key path at import time so the app can serve
# requests without a per-session /init call when a key is configured.
load_global_retriever()
5542
5643
def get_llm():
    """Return the completion LLM used to answer retrieval questions."""
    return OpenAI(temperature=0.6, model_name="gpt-3.5-turbo-instruct")
6047
6148
62- def get_qa_chain (llm , retriever ):
63- PROMPT = None
64- prompt_template = """
def delete_chat_message_history(session):
    """Clear the Mongo-backed chat history for *session*.

    Returns True on success, False if clearing failed (error is logged,
    never raised to the caller).
    """
    try:
        # Reuse the shared constructor instead of duplicating the
        # MongoDBChatMessageHistory wiring inline (keeps collection name
        # and connection settings in one place).
        get_chat_message_history(session).clear()
    except Exception as e:
        app.logger.error("Error deleting message history %s for session %s", e, session)
        return False
    return True
64+
65+
def get_chat_message_history(session):
    """Return the Mongo-backed chat history handle for *session*."""
    return MongoDBChatMessageHistory(
        session_id=session,
        connection_string=MONGO_CONNECTION_URI,
        database_name=MONGO_DB_NAME,
        collection_name="chat_histories",
    )
74+
75+
def get_qa_chain(llm, retriever, session):
    """Build a RetrievalQAWithSourcesChain with Mongo-persisted memory.

    Args:
        llm: the language model used to answer.
        retriever: document retriever supplying {summaries} context.
        session: session id keying the persisted chat history.

    Returns:
        A configured RetrievalQAWithSourcesChain whose window memory (k=6)
        is backed by the session's MongoDB chat history.
    """
    system_prompt_template = """
    You are a helpful AI Assistant.
    {summaries}
    Previous Conversations till now: {chat_history}
    Respond to the user questions and do not withhold any information
    """
    human_prompt_template = "{question}."

    messages = [
        SystemMessagePromptTemplate.from_template(system_prompt_template),
        HumanMessagePromptTemplate.from_template(human_prompt_template),
    ]
    prompt = ChatPromptTemplate.from_messages(messages)
    chain_type_kwargs = {"prompt": prompt}

    # Reuse the shared helper rather than re-building the history inline —
    # identical wiring lives in get_chat_message_history().
    chat_message_history = get_chat_message_history(session)

    qa = RetrievalQAWithSourcesChain.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs=chain_type_kwargs,
        memory=ConversationBufferWindowMemory(
            memory_key="chat_history",
            input_key="question",
            output_key="answer",
            k=6,
            chat_memory=chat_message_history,
        ),
    )
    return qa
85115
86116
def qa_answer(model, query):
    """Run *query* through the QA chain and return the answer text."""
    result = model.invoke(query)
    answer = result["answer"]
    app.logger.debug("Answering question %s", answer)
    return answer
90121
91122
# NOTE(review): decorator/def line sit above this hunk in the real file —
# route path reconstructed from the sibling endpoints; confirm it.
@app.route("/chatbot/genai/init", methods=["POST"])
def init_bot():
    """Initialize a per-session retriever from a caller-supplied OpenAI key.

    If the server already has an env-provided key, report that instead of
    re-initializing. Returns 200 on success, 400 on missing key, 500 on
    any initialization failure.
    """
    app.logger.debug("Initializing bot")
    try:
        with loaded_model_lock:
            client_ip = request.headers.get("X-Forwarded-For", request.remote_addr)
            session = request.headers.get("authorization", client_ip)
            if os.environ.get("CHATBOT_OPENAI_API_KEY"):
                app.logger.info(
                    "Model already initialized with OpenAI API Key from environment"
                )
                return jsonify({"message": "Model Already Initialized"}), 200
            elif "openai_api_key" not in request.json:
                app.logger.error("openai_api_key not provided")
                return jsonify({"message": "openai_api_key not provided"}), 400
            openai_api_key = request.json["openai_api_key"]
            # NOTE(review): this logs the caller's API key — sensitive data
            # in logs; confirm it is intentional for this training app.
            app.logger.debug("Initializing bot %s", request.json["openai_api_key"])
            retriever_l = document_loader(openai_api_key, app.logger)
            session_model_map[session] = retriever_l
            # Fix: success path previously returned HTTP 400.
            return jsonify({"message": "Model Initialized"}), 200

    except Exception as e:
        # Fix: logger calls passed `e` with no %s placeholder, which makes
        # the logging module raise a formatting error internally.
        app.logger.error("Error initializing bot %s", e)
        app.logger.debug("Error initializing bot %s", e, exc_info=True)
        return jsonify({"message": "Not able to initialize model " + str(e)}), 500
115148
116149
@app.route("/chatbot/genai/state", methods=["POST"])
def state_bot():
    """Report whether a QA model is available for the caller's session.

    Initialized when either the env-key retriever is loaded
    (working_key_event) or the session has its own retriever.
    Always responds 200 with an {"initialized": ...} payload.
    """
    app.logger.debug("Checking state")
    client_ip = request.headers.get("X-Forwarded-For", request.remote_addr)
    session = request.headers.get("authorization", client_ip)
    app.logger.debug("Checking state for session %s", session)
    try:
        # Two equivalent "already loaded" branches merged into one check.
        if working_key_event.is_set() or session_model_map.get(session):
            return (
                jsonify({"initialized": "true", "message": "Model already loaded"}),
                200,
            )
        # Robustness: a POST without a JSON body must not 500 here.
        body = request.get_json(silent=True) or {}
        if not body.get("openai_api_key"):
            return (
                jsonify(
                    {
                        "initialized": "false",
                        "message": "API Key not set for OpenAI",
                    }
                ),
                200,
            )
    except Exception as e:
        app.logger.error("Error checking state %s", e)
        # Fix: status code was passed *inside* jsonify(), producing a
        # two-element JSON array instead of a (body, status) response.
        return jsonify({"message": "Error checking state " + str(e)}), 200
    return (
        jsonify({"initialized": "false", "message": "Model needs to be initialized"}),
        200,
    )
184+
185+
@app.route("/chatbot/genai/reset", methods=["POST"])
def reset_chat_history_bot():
    """Wipe the stored chat history bound to the caller's session."""
    client_ip = request.headers.get("X-Forwarded-For", request.remote_addr)
    session = request.headers.get("authorization", client_ip)

    if delete_chat_message_history(session=session):
        return jsonify({"message": "Deleted chat history"}), 200
    return jsonify({"message": "Error deleting chat history"}), 500
130195
131196
@app.route("/chatbot/genai/ask", methods=["POST"])
def ask_bot():
    """Answer a question using the retriever bound to the caller's session.

    Prefers the global env-key retriever; otherwise falls back to the
    per-session retriever created by /init. Responds 500 when neither
    exists, 200 with the answer otherwise.
    """
    client_ip = request.headers.get("X-Forwarded-For", request.remote_addr)
    session = request.headers.get("authorization", client_ip)
    if retriever:
        retriever_l = retriever
    else:
        with loaded_model_lock:
            retriever_l = session_model_map.get(session)
    if retriever_l is None:
        app.logger.error("Model not initialized for session %s", session)
        return (
            jsonify(
                {
                    "initialized": "false",
                    # Fix: the %s placeholder was sent to the client unfilled.
                    "message": "Model not initialized for session %s" % session,
                }
            ),
            500,
        )
    app.logger.debug("Asking bot")
    question = request.json["question"]
    llm = get_llm()
    model = get_qa_chain(llm, retriever_l, session)
    answer = qa_answer(model, question)
    app.logger.info("###########################################")
    app.logger.info("Attacker Question: : %s", question)
    app.logger.info("App Answer: : %s", answer)
    app.logger.info("###########################################")
    # Fix: response key was '" answer"' (leading space), breaking clients
    # that look up "answer".
    return jsonify({"initialized": "true", "answer": answer}), 200
143228
144229
145230if __name__ == "__main__" :
0 commit comments